aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/CMakeLists.txt1
-rw-r--r--mlir/include/mlir/Dialect/GPU/Transforms/Passes.h14
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td6
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td46
-rw-r--r--mlir/include/mlir/Dialect/Math/Transforms/Passes.h3
-rw-r--r--mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h9
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td40
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td2
-rw-r--r--mlir/include/mlir/Dialect/Utils/IndexingUtils.h3
-rw-r--r--mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h43
-rw-r--r--mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h24
-rw-r--r--mlir/include/mlir/Interfaces/FunctionInterfaces.td226
-rw-r--r--mlir/include/mlir/Transforms/DialectConversion.h95
-rw-r--r--mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp18
-rw-r--r--mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp24
-rw-r--r--mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp28
-rw-r--r--mlir/lib/Dialect/GPU/CMakeLists.txt52
-rw-r--r--mlir/lib/Dialect/GPU/IR/GPUDialect.cpp11
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp180
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp5
-rw-r--r--mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp10
-rw-r--r--mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp45
-rw-r--r--mlir/lib/Dialect/SCF/IR/SCF.cpp15
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaOps.cpp2
-rw-r--r--mlir/lib/Dialect/Utils/IndexingUtils.cpp11
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp27
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp41
-rw-r--r--mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp7
-rw-r--r--mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp50
-rw-r--r--mlir/lib/Target/LLVMIR/DebugImporter.cpp35
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleImport.cpp9
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp1574
-rw-r--r--mlir/test/CAPI/llvm.c3
-rw-r--r--mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir81
-rw-r--r--mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir15
-rw-r--r--mlir/test/Dialect/Linalg/canonicalize.mlir55
-rw-r--r--mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir40
-rw-r--r--mlir/test/Dialect/Linalg/invalid.mlir10
-rw-r--r--mlir/test/Dialect/Mesh/spmdization.mlir14
-rw-r--r--mlir/test/Dialect/Tosa/ops.mlir14
-rw-r--r--mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir82
-rw-r--r--mlir/test/Dialect/Vector/vector-transfer-flatten.mlir38
-rw-r--r--mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir9
-rw-r--r--mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir9
-rw-r--r--mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir8
-rw-r--r--mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir6
-rw-r--r--mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir7
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir6
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir6
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir8
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir10
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir6
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir6
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir8
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir6
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir7
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir6
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir6
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir6
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir5
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir25
-rw-r--r--mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir24
-rw-r--r--mlir/test/Target/LLVMIR/Import/import-failure.ll40
-rw-r--r--mlir/test/lib/Dialect/ArmSME/CMakeLists.txt18
-rw-r--r--mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp99
-rw-r--r--mlir/test/lib/Dialect/CMakeLists.txt1
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp14
-rw-r--r--mlir/tools/mlir-opt/CMakeLists.txt1
-rw-r--r--mlir/tools/mlir-opt/mlir-opt.cpp2
-rw-r--r--mlir/unittests/Dialect/SPIRV/SerializationTest.cpp6
-rw-r--r--mlir/unittests/IR/InterfaceAttachmentTest.cpp2
-rw-r--r--mlir/unittests/IR/OperationSupportTest.cpp4
-rw-r--r--mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp8
73 files changed, 1685 insertions, 1692 deletions
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index 2d9f78e..16c898b 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -123,7 +123,6 @@ else()
endif()
add_definitions(-DMLIR_ROCM_CONVERSIONS_ENABLED=${MLIR_ENABLE_ROCM_CONVERSIONS})
-set(MLIR_ENABLE_DEPRECATED_GPU_SERIALIZATION 0 CACHE BOOL "Enable deprecated GPU serialization passes")
set(MLIR_ENABLE_CUDA_RUNNER 0 CACHE BOOL "Enable building the mlir CUDA runner")
set(MLIR_ENABLE_ROCM_RUNNER 0 CACHE BOOL "Enable building the mlir ROCm runner")
set(MLIR_ENABLE_SYCL_RUNNER 0 CACHE BOOL "Enable building the mlir Sycl runner")
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 5885fac..8f7466a 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -147,25 +147,11 @@ protected:
// Registration
//===----------------------------------------------------------------------===//
-/// Register pass to serialize GPU kernel functions to a CUBIN binary
-/// annotation.
-LLVM_DEPRECATED("use Target attributes instead", "")
-void registerGpuSerializeToCubinPass();
-
/// Register pass to serialize GPU kernel functions to a HSAco binary
/// annotation.
LLVM_DEPRECATED("use Target attributes instead", "")
void registerGpuSerializeToHsacoPass();
-/// Create an instance of the GPU kernel function to CUBIN binary serialization
-/// pass with optLevel (default level 2).
-LLVM_DEPRECATED("use Target attributes instead", "")
-std::unique_ptr<Pass> createGpuSerializeToCubinPass(StringRef triple,
- StringRef chip,
- StringRef features,
- int optLevel = 2,
- bool dumpPtx = false);
-
/// Create an instance of the GPU kernel function to HSAco binary serialization
/// pass.
LLVM_DEPRECATED("use Target attributes instead", "")
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index feb3578..b88f118 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -513,7 +513,11 @@ def LLVM_DbgLabelOp : LLVM_IntrOp<"dbg.label", [], [], [], 0> {
});
}];
let mlirBuilder = [{
- $_op = $_builder.create<$_qualCppClassName>($_location, $_label_attr($label));
+ DILabelAttr labelAttr = $_label_attr($label);
+ // Drop the intrinsic if the label translation fails due to cylic metadata.
+ if (!labelAttr)
+ return success();
+ $_op = $_builder.create<$_qualCppClassName>($_location, labelAttr);
}];
let assemblyFormat = "$label attr-dict";
}
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index d9b130b..3da5dee 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -49,7 +49,7 @@ class LLVM_ArithmeticOpBase<Type type, string mnemonic,
}
class LLVM_IntArithmeticOp<string mnemonic, string instName,
list<Trait> traits = []> :
- LLVM_ArithmeticOpBase<AnyInteger, mnemonic, instName, traits> {
+ LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName, traits> {
let arguments = commonArgs;
string mlirBuilder = [{
$res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
@@ -57,7 +57,7 @@ class LLVM_IntArithmeticOp<string mnemonic, string instName,
}
class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
list<Trait> traits = []> :
- LLVM_ArithmeticOpBase<AnyInteger, mnemonic, instName,
+ LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
!listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> {
dag iofArg = (
ins DefaultValuedAttr<LLVM_IntegerOverflowFlagsAttr, "{}">:$overflowFlags);
@@ -143,9 +143,9 @@ class LLVM_ArithmeticCmpOp<string mnemonic, list<Trait> traits = []> :
// Other integer operations.
def LLVM_ICmpOp : LLVM_ArithmeticCmpOp<"icmp", [Pure]> {
let arguments = (ins ICmpPredicate:$predicate,
- AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>,
+ AnyTypeOf<[LLVM_ScalarOrVectorOf<AnySignlessInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$lhs,
- AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>,
+ AnyTypeOf<[LLVM_ScalarOrVectorOf<AnySignlessInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$rhs);
let hasCustomAssemblyFormat = 1;
string llvmInstName = "ICmp";
@@ -204,7 +204,7 @@ def LLVM_AllocaOp : LLVM_Op<"alloca",
DeclareOpInterfaceMethods<DestructurableAllocationOpInterface>,
DeclareOpInterfaceMethods<GetResultPtrElementType>]>,
LLVM_MemOpPatterns {
- let arguments = (ins AnyInteger:$arraySize,
+ let arguments = (ins AnySignlessInteger:$arraySize,
OptionalAttr<I64Attr>:$alignment,
TypeAttr:$elem_type,
UnitAttr:$inalloca);
@@ -250,7 +250,7 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
DeclareOpInterfaceMethods<GetResultPtrElementType>]> {
let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$base,
- Variadic<LLVM_ScalarOrVectorOf<AnyInteger>>:$dynamicIndices,
+ Variadic<LLVM_ScalarOrVectorOf<AnySignlessInteger>>:$dynamicIndices,
DenseI32ArrayAttr:$rawConstantIndices,
TypeAttr:$elem_type,
UnitAttr:$inbounds);
@@ -499,37 +499,37 @@ def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "AddrSpaceCast",
let hasFolder = 1;
}
def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "IntToPtr",
- LLVM_ScalarOrVectorOf<AnyInteger>,
+ LLVM_ScalarOrVectorOf<AnySignlessInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>;
def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "PtrToInt",
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>,
- LLVM_ScalarOrVectorOf<AnyInteger>>;
+ LLVM_ScalarOrVectorOf<AnySignlessInteger>>;
def LLVM_SExtOp : LLVM_CastOp<"sext", "SExt",
- LLVM_ScalarOrVectorOf<AnyInteger>,
- LLVM_ScalarOrVectorOf<AnyInteger>> {
+ LLVM_ScalarOrVectorOf<AnySignlessInteger>,
+ LLVM_ScalarOrVectorOf<AnySignlessInteger>> {
let hasVerifier = 1;
}
def LLVM_ZExtOp : LLVM_CastOp<"zext", "ZExt",
- LLVM_ScalarOrVectorOf<AnyInteger>,
- LLVM_ScalarOrVectorOf<AnyInteger>> {
+ LLVM_ScalarOrVectorOf<AnySignlessInteger>,
+ LLVM_ScalarOrVectorOf<AnySignlessInteger>> {
let hasFolder = 1;
let hasVerifier = 1;
}
def LLVM_TruncOp : LLVM_CastOp<"trunc", "Trunc",
- LLVM_ScalarOrVectorOf<AnyInteger>,
- LLVM_ScalarOrVectorOf<AnyInteger>>;
+ LLVM_ScalarOrVectorOf<AnySignlessInteger>,
+ LLVM_ScalarOrVectorOf<AnySignlessInteger>>;
def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "SIToFP",
- LLVM_ScalarOrVectorOf<AnyInteger>,
+ LLVM_ScalarOrVectorOf<AnySignlessInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "UIToFP",
- LLVM_ScalarOrVectorOf<AnyInteger>,
+ LLVM_ScalarOrVectorOf<AnySignlessInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "FPToSI",
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
- LLVM_ScalarOrVectorOf<AnyInteger>>;
+ LLVM_ScalarOrVectorOf<AnySignlessInteger>>;
def LLVM_FPToUIOp : LLVM_CastOp<"fptoui", "FPToUI",
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
- LLVM_ScalarOrVectorOf<AnyInteger>>;
+ LLVM_ScalarOrVectorOf<AnySignlessInteger>>;
def LLVM_FPExtOp : LLVM_CastOp<"fpext", "FPExt",
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
@@ -671,7 +671,7 @@ def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [Pure,
"LLVM::getVectorElementType($_self)">]> {
let summary = "Extract an element from an LLVM vector.";
- let arguments = (ins LLVM_AnyVector:$vector, AnyInteger:$position);
+ let arguments = (ins LLVM_AnyVector:$vector, AnySignlessInteger:$position);
let results = (outs LLVM_Type:$res);
let assemblyFormat = [{
@@ -733,7 +733,7 @@ def LLVM_InsertElementOp : LLVM_Op<"insertelement", [Pure,
let summary = "Insert an element into an LLVM vector.";
let arguments = (ins LLVM_AnyVector:$vector, LLVM_PrimitiveType:$value,
- AnyInteger:$position);
+ AnySignlessInteger:$position);
let results = (outs LLVM_AnyVector:$res);
let builders = [LLVM_OneResultOpBuilder];
@@ -971,7 +971,7 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
Pure]> {
let arguments = (ins
- AnyInteger:$value,
+ AnySignlessInteger:$value,
Variadic<AnyType>:$defaultOperands,
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
OptionalAttr<AnyIntElementsAttr>:$case_values,
@@ -1647,7 +1647,7 @@ def LLVM_ConstantOp
// Atomic operations.
//
-def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnyInteger]>;
+def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnySignlessInteger]>;
def LLVM_AtomicRMWOp : LLVM_MemAccessOpBase<"atomicrmw", [
TypesMatchWith<"result #0 and operand #1 have the same type",
@@ -1696,7 +1696,7 @@ def LLVM_AtomicRMWOp : LLVM_MemAccessOpBase<"atomicrmw", [
let hasVerifier = 1;
}
-def LLVM_AtomicCmpXchgType : AnyTypeOf<[AnyInteger, LLVM_AnyPointer]>;
+def LLVM_AtomicCmpXchgType : AnyTypeOf<[AnySignlessInteger, LLVM_AnyPointer]>;
def LLVM_AtomicCmpXchgOp : LLVM_MemAccessOpBase<"cmpxchg", [
TypesMatchWith<"operand #1 and operand #2 have the same type",
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 010dde5..11b2c7a 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -45,6 +45,9 @@ struct MathPolynomialApproximationOptions {
bool enableAvx2 = false;
};
+void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns);
+void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns);
+
void populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns,
const MathPolynomialApproximationOptions &options = {});
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index a00c9c3..1c81d80 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -145,12 +145,9 @@ constexpr bool isComplexPrimaryType(PrimaryType valTy) {
/// The actions performed by @newSparseTensor.
enum class Action : uint32_t {
kEmpty = 0,
- kEmptyForward = 1,
- kFromCOO = 2,
- kFromReader = 4,
- kToCOO = 5,
- kPack = 7,
- kSortCOOInPlace = 8,
+ kFromReader = 1,
+ kPack = 2,
+ kSortCOOInPlace = 3,
};
/// This enum defines all supported storage format without the level properties.
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 0ee9e71..0ecded7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -990,6 +990,26 @@ def Tosa_ClzOp : Tosa_ElementwiseOp<"clz", [SameOperandsAndResultElementType]> {
}
//===----------------------------------------------------------------------===//
+// Operator: cos
+//===----------------------------------------------------------------------===//
+def Tosa_CosOp : Tosa_ElementwiseOp<"cos",
+ [SameOperandsAndResultElementType]> {
+ let summary = "Elementwise cos op";
+
+ let description = [{
+ Elementwise cosine operation for values given in radians.
+ }];
+
+ let arguments = (ins
+ Tosa_FloatTensor:$input
+ );
+
+ let results = (outs
+ Tosa_FloatTensor:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
// Operator: exp
//===----------------------------------------------------------------------===//
def Tosa_ExpOp : Tosa_ElementwiseOp<"exp", [SameOperandsAndResultElementType]> {
@@ -1149,6 +1169,26 @@ def Tosa_RsqrtOp : Tosa_ElementwiseOp<"rsqrt",
}
//===----------------------------------------------------------------------===//
+// Operator: sin
+//===----------------------------------------------------------------------===//
+def Tosa_SinOp : Tosa_ElementwiseOp<"sin",
+ [SameOperandsAndResultElementType]> {
+ let summary = "Elementwise sin op";
+
+ let description = [{
+ Elementwise sine operation for values given in radians.
+ }];
+
+ let arguments = (ins
+ Tosa_FloatTensor:$input
+ );
+
+ let results = (outs
+ Tosa_FloatTensor:$output
+ );
+}
+
+//===----------------------------------------------------------------------===//
// TOSA Spec Section 2.6
// Operator Class: Elementwise unary/binary/ternary operators.
// Operator Subclass: Elementwise ternary ops.
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index c55ddaa..5a4d6ff 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -113,6 +113,8 @@ def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>;
def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;
+def Tosa_FloatTensor : TensorOf<[Tosa_Float]>;
+
// Either ranked or unranked tensor of TOSA supported element types.
def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
def Tosa_Tensor_Plus_F64 : TensorOf<[Tosa_AnyNumber_Plus_F64]>;
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 2453d84..9892253 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -257,6 +257,9 @@ SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
std::pair<AffineExpr, SmallVector<OpFoldResult>>
computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<OpFoldResult> strides,
ArrayRef<OpFoldResult> indices);
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
+ ArrayRef<Value> indices);
//===----------------------------------------------------------------------===//
// Utilities for decomposing larger shapes
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index eff1aca..fe0e08b 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -149,13 +149,6 @@ public:
MLIR_SPARSETENSOR_FOREVERY_V(DECL_GETVALUES)
#undef DECL_GETVALUES
- /// Element-wise forwarding insertions. The first argument is the
- /// dimension-coordinates for the value being inserted.
-#define DECL_FORWARDINGINSERT(VNAME, V) \
- virtual void forwardingInsert(const uint64_t *, V);
- MLIR_SPARSETENSOR_FOREVERY_V(DECL_FORWARDINGINSERT)
-#undef DECL_FORWARDINGINSERT
-
/// Element-wise insertion in lexicographic coordinate order. The first
/// argument is the level-coordinates for the value being inserted.
#define DECL_LEXINSERT(VNAME, V) virtual void lexInsert(const uint64_t *, V);
@@ -171,9 +164,6 @@ public:
MLIR_SPARSETENSOR_FOREVERY_V(DECL_EXPINSERT)
#undef DECL_EXPINSERT
- /// Finalizes forwarding insertions.
- virtual void endForwardingInsert() = 0;
-
/// Finalizes lexicographic insertions.
virtual void endLexInsert() = 0;
@@ -248,7 +238,7 @@ public:
static SparseTensorStorage<P, C, V> *
newEmpty(uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
const uint64_t *lvlSizes, const LevelType *lvlTypes,
- const uint64_t *dim2lvl, const uint64_t *lvl2dim, bool forwarding);
+ const uint64_t *dim2lvl, const uint64_t *lvl2dim);
/// Allocates a new sparse tensor and initializes it from the given COO.
static SparseTensorStorage<P, C, V> *
@@ -284,13 +274,6 @@ public:
*out = &values;
}
- /// Partially specialize forwarding insertions based on template types.
- void forwardingInsert(const uint64_t *dimCoords, V val) final {
- assert(dimCoords && coo);
- map.pushforward(dimCoords, lvlCursor.data());
- coo->add(lvlCursor, val);
- }
-
/// Partially specialize lexicographical insertions based on template types.
void lexInsert(const uint64_t *lvlCoords, V val) final {
assert(lvlCoords);
@@ -345,21 +328,6 @@ public:
}
}
- /// Finalizes forwarding insertions.
- void endForwardingInsert() final {
- // Ensure COO is sorted.
- assert(coo);
- coo->sort();
- // Now actually insert the `elements`.
- const auto &elements = coo->getElements();
- const uint64_t nse = elements.size();
- assert(values.size() == 0);
- values.reserve(nse);
- fromCOO(elements, 0, nse, 0);
- delete coo;
- coo = nullptr;
- }
-
/// Finalizes lexicographic insertions.
void endLexInsert() final {
if (!allDense) {
@@ -653,13 +621,10 @@ template <typename P, typename C, typename V>
SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::newEmpty(
uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
const uint64_t *lvlSizes, const LevelType *lvlTypes,
- const uint64_t *dim2lvl, const uint64_t *lvl2dim, bool forwarding) {
- SparseTensorCOO<V> *lvlCOO = nullptr;
- if (forwarding)
- lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes);
+ const uint64_t *dim2lvl, const uint64_t *lvl2dim) {
return new SparseTensorStorage<P, C, V>(dimRank, dimSizes, lvlRank, lvlSizes,
- lvlTypes, dim2lvl, lvl2dim, lvlCOO,
- !forwarding);
+ lvlTypes, dim2lvl, lvl2dim, nullptr,
+ true);
}
template <typename P, typename C, typename V>
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
index 8b0829a..d916186 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
@@ -38,15 +38,12 @@ extern "C" {
/// This is the "swiss army knife" method for materializing sparse
/// tensors into the computation. The types of the `ptr` argument and
/// the result depend on the action, as explained in the following table,
-/// where "STS" means a sparse-tensor-storage object and "COO" means
-/// a coordinate-scheme object.
+/// where "STS" means a sparse-tensor-storage object.
///
/// Action: `ptr`: Returns:
+/// ---------------------------------------------------------------------------
/// kEmpty - STS, empty
-/// kEmptyForward - STS, empty, with forwarding COO
-/// kFromCOO COO STS, copied from the COO source
/// kFromReader reader STS, input from reader
-/// kToCOO STS COO, copied from the STS source
/// kPack buffers STS, from level buffers
/// kSortCOOInPlace STS STS, sorted in place
MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_newSparseTensor( // NOLINT
@@ -80,14 +77,6 @@ MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSEPOSITIONS)
MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSECOORDINATES)
#undef DECL_SPARSECOORDINATES
-/// Tensor-storage method for a dim to lvl forwarding insertion.
-#define DECL_FORWARDINGINSERT(VNAME, V) \
- MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_forwardingInsert##VNAME( \
- void *tensor, StridedMemRefType<V, 0> *vref, \
- StridedMemRefType<index_type, 1> *dimCoordsRef); \
- MLIR_SPARSETENSOR_FOREVERY_V(DECL_FORWARDINGINSERT)
-#undef DECL_FORWARDINGINSERT
-
/// Tensor-storage method to insert elements in lexicographical
/// level-coordinate order.
#define DECL_LEXINSERT(VNAME, V) \
@@ -160,21 +149,12 @@ MLIR_CRUNNERUTILS_EXPORT index_type sparseLvlSize(void *tensor, index_type l);
/// Tensor-storage method to get the size of the given dimension.
MLIR_CRUNNERUTILS_EXPORT index_type sparseDimSize(void *tensor, index_type d);
-/// Tensor-storage method to finalize forwarding insertions.
-MLIR_CRUNNERUTILS_EXPORT void endForwardingInsert(void *tensor);
-
/// Tensor-storage method to finalize lexicographic insertions.
MLIR_CRUNNERUTILS_EXPORT void endLexInsert(void *tensor);
/// Releases the memory for the tensor-storage object.
MLIR_CRUNNERUTILS_EXPORT void delSparseTensor(void *tensor);
-/// Releases the memory for the coordinate-scheme object.
-#define DECL_DELCOO(VNAME, V) \
- MLIR_CRUNNERUTILS_EXPORT void delSparseTensorCOO##VNAME(void *coo);
-MLIR_SPARSETENSOR_FOREVERY_V(DECL_DELCOO)
-#undef DECL_DELCOO
-
/// Helper function to read a sparse tensor filename from the environment,
/// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
MLIR_CRUNNERUTILS_EXPORT char *getTensorFilename(index_type id);
diff --git a/mlir/include/mlir/Interfaces/FunctionInterfaces.td b/mlir/include/mlir/Interfaces/FunctionInterfaces.td
index 98e0025..970a781 100644
--- a/mlir/include/mlir/Interfaces/FunctionInterfaces.td
+++ b/mlir/include/mlir/Interfaces/FunctionInterfaces.td
@@ -147,12 +147,12 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
}];
let extraSharedClassDeclaration = [{
/// Block list iterator types.
- using BlockListType = Region::BlockListType;
+ using BlockListType = ::mlir::Region::BlockListType;
using iterator = BlockListType::iterator;
using reverse_iterator = BlockListType::reverse_iterator;
/// Block argument iterator types.
- using BlockArgListType = Region::BlockArgListType;
+ using BlockArgListType = ::mlir::Region::BlockArgListType;
using args_iterator = BlockArgListType::iterator;
//===------------------------------------------------------------------===//
@@ -163,7 +163,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
bool isExternal() { return empty(); }
/// Return the region containing the body of this function.
- Region &getFunctionBody() { return $_op->getRegion(0); }
+ ::mlir::Region &getFunctionBody() { return $_op->getRegion(0); }
/// Delete all blocks from this function.
void eraseBody() {
@@ -183,39 +183,39 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
bool empty() { return getFunctionBody().empty(); }
/// Push a new block to the back of the body region.
- void push_back(Block *block) { getFunctionBody().push_back(block); }
+ void push_back(::mlir::Block *block) { getFunctionBody().push_back(block); }
/// Push a new block to the front of the body region.
- void push_front(Block *block) { getFunctionBody().push_front(block); }
+ void push_front(::mlir::Block *block) { getFunctionBody().push_front(block); }
/// Return the last block in the body region.
- Block &back() { return getFunctionBody().back(); }
+ ::mlir::Block &back() { return getFunctionBody().back(); }
/// Return the first block in the body region.
- Block &front() { return getFunctionBody().front(); }
+ ::mlir::Block &front() { return getFunctionBody().front(); }
/// Add an entry block to an empty function, and set up the block arguments
/// to match the signature of the function. The newly inserted entry block
/// is returned.
- Block *addEntryBlock() {
+ ::mlir::Block *addEntryBlock() {
assert(empty() && "function already has an entry block");
- Block *entry = new Block();
+ ::mlir::Block *entry = new ::mlir::Block();
push_back(entry);
// FIXME: Allow for passing in locations for these arguments instead of using
// the operations location.
- ArrayRef<Type> inputTypes = $_op.getArgumentTypes();
- SmallVector<Location> locations(inputTypes.size(),
- $_op.getOperation()->getLoc());
+ ::llvm::ArrayRef<::mlir::Type> inputTypes = $_op.getArgumentTypes();
+ ::llvm::SmallVector<::mlir::Location> locations(inputTypes.size(),
+ $_op.getOperation()->getLoc());
entry->addArguments(inputTypes, locations);
return entry;
}
/// Add a normal block to the end of the function's block list. The function
/// should at least already have an entry block.
- Block *addBlock() {
+ ::mlir::Block *addBlock() {
assert(!empty() && "function should at least have an entry block");
- push_back(new Block());
+ push_back(new ::mlir::Block());
return &back();
}
@@ -230,8 +230,8 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
/// - the argument/result attributes may need an update: if the new type
/// has less parameters we drop the extra attributes, if there are more
/// parameters they won't have any attributes.
- void setType(Type newType) {
- function_interface_impl::setFunctionType($_op, newType);
+ void setType(::mlir::Type newType) {
+ ::mlir::function_interface_impl::setFunctionType($_op, newType);
}
//===------------------------------------------------------------------===//
@@ -245,7 +245,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
unsigned getNumResults() { return $_op.getResultTypes().size(); }
/// Returns the entry block function argument at the given index.
- BlockArgument getArgument(unsigned idx) {
+ ::mlir::BlockArgument getArgument(unsigned idx) {
return getFunctionBody().getArgument(idx);
}
@@ -256,8 +256,8 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
/// Insert a single argument of type `argType` with attributes `argAttrs` and
/// location `argLoc` at `argIndex`.
- void insertArgument(unsigned argIndex, Type argType, DictionaryAttr argAttrs,
- Location argLoc) {
+ void insertArgument(unsigned argIndex, ::mlir::Type argType, ::mlir::DictionaryAttr argAttrs,
+ ::mlir::Location argLoc) {
insertArguments({argIndex}, {argType}, {argAttrs}, {argLoc});
}
@@ -265,20 +265,20 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
/// listed indices. `argIndices` must be sorted. Arguments are inserted in the
/// order they are listed, such that arguments with identical index will
/// appear in the same order that they were listed here.
- void insertArguments(ArrayRef<unsigned> argIndices, TypeRange argTypes,
- ArrayRef<DictionaryAttr> argAttrs,
- ArrayRef<Location> argLocs) {
+ void insertArguments(::llvm::ArrayRef<unsigned> argIndices, ::mlir::TypeRange argTypes,
+ ::llvm::ArrayRef<::mlir::DictionaryAttr> argAttrs,
+ ::llvm::ArrayRef<::mlir::Location> argLocs) {
unsigned originalNumArgs = $_op.getNumArguments();
- Type newType = $_op.getTypeWithArgsAndResults(
+ ::mlir::Type newType = $_op.getTypeWithArgsAndResults(
argIndices, argTypes, /*resultIndices=*/{}, /*resultTypes=*/{});
- function_interface_impl::insertFunctionArguments(
+ ::mlir::function_interface_impl::insertFunctionArguments(
$_op, argIndices, argTypes, argAttrs, argLocs,
originalNumArgs, newType);
}
/// Insert a single result of type `resultType` at `resultIndex`.
- void insertResult(unsigned resultIndex, Type resultType,
- DictionaryAttr resultAttrs) {
+ void insertResult(unsigned resultIndex, ::mlir::Type resultType,
+ ::mlir::DictionaryAttr resultAttrs) {
insertResults({resultIndex}, {resultType}, {resultAttrs});
}
@@ -286,41 +286,41 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
/// `resultIndices` must be sorted. Results are inserted in the order they are
/// listed, such that results with identical index will appear in the same
/// order that they were listed here.
- void insertResults(ArrayRef<unsigned> resultIndices, TypeRange resultTypes,
- ArrayRef<DictionaryAttr> resultAttrs) {
+ void insertResults(::llvm::ArrayRef<unsigned> resultIndices, ::mlir::TypeRange resultTypes,
+ ::llvm::ArrayRef<::mlir::DictionaryAttr> resultAttrs) {
unsigned originalNumResults = $_op.getNumResults();
- Type newType = $_op.getTypeWithArgsAndResults(
+ ::mlir::Type newType = $_op.getTypeWithArgsAndResults(
/*argIndices=*/{}, /*argTypes=*/{}, resultIndices, resultTypes);
- function_interface_impl::insertFunctionResults(
+ ::mlir::function_interface_impl::insertFunctionResults(
$_op, resultIndices, resultTypes, resultAttrs,
originalNumResults, newType);
}
/// Erase a single argument at `argIndex`.
void eraseArgument(unsigned argIndex) {
- BitVector argsToErase($_op.getNumArguments());
+ ::llvm::BitVector argsToErase($_op.getNumArguments());
argsToErase.set(argIndex);
eraseArguments(argsToErase);
}
/// Erases the arguments listed in `argIndices`.
- void eraseArguments(const BitVector &argIndices) {
- Type newType = $_op.getTypeWithoutArgs(argIndices);
- function_interface_impl::eraseFunctionArguments(
+ void eraseArguments(const ::llvm::BitVector &argIndices) {
+ ::mlir::Type newType = $_op.getTypeWithoutArgs(argIndices);
+ ::mlir::function_interface_impl::eraseFunctionArguments(
$_op, argIndices, newType);
}
/// Erase a single result at `resultIndex`.
void eraseResult(unsigned resultIndex) {
- BitVector resultsToErase($_op.getNumResults());
+ ::llvm::BitVector resultsToErase($_op.getNumResults());
resultsToErase.set(resultIndex);
eraseResults(resultsToErase);
}
/// Erases the results listed in `resultIndices`.
- void eraseResults(const BitVector &resultIndices) {
- Type newType = $_op.getTypeWithoutResults(resultIndices);
- function_interface_impl::eraseFunctionResults(
+ void eraseResults(const ::llvm::BitVector &resultIndices) {
+ ::mlir::Type newType = $_op.getTypeWithoutResults(resultIndices);
+ ::mlir::function_interface_impl::eraseFunctionResults(
$_op, resultIndices, newType);
}
@@ -328,13 +328,13 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
/// results inserted. This is used to update the function's signature in
/// the `insertArguments` and `insertResults` methods. The arrays must be
/// sorted by increasing index.
- Type getTypeWithArgsAndResults(
- ArrayRef<unsigned> argIndices, TypeRange argTypes,
- ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
- SmallVector<Type> argStorage, resultStorage;
- TypeRange newArgTypes = insertTypesInto(
+ ::mlir::Type getTypeWithArgsAndResults(
+ ::llvm::ArrayRef<unsigned> argIndices, ::mlir::TypeRange argTypes,
+ ::llvm::ArrayRef<unsigned> resultIndices, ::mlir::TypeRange resultTypes) {
+ ::llvm::SmallVector<::mlir::Type> argStorage, resultStorage;
+ ::mlir::TypeRange newArgTypes = insertTypesInto(
$_op.getArgumentTypes(), argIndices, argTypes, argStorage);
- TypeRange newResultTypes = insertTypesInto(
+ ::mlir::TypeRange newResultTypes = insertTypesInto(
$_op.getResultTypes(), resultIndices, resultTypes, resultStorage);
return $_op.cloneTypeWith(newArgTypes, newResultTypes);
}
@@ -342,24 +342,24 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
/// Return the type of this function without the specified arguments and
/// results. This is used to update the function's signature in the
/// `eraseArguments` and `eraseResults` methods.
- Type getTypeWithoutArgsAndResults(
- const BitVector &argIndices, const BitVector &resultIndices) {
- SmallVector<Type> argStorage, resultStorage;
- TypeRange newArgTypes = filterTypesOut(
+ ::mlir::Type getTypeWithoutArgsAndResults(
+ const ::llvm::BitVector &argIndices, const ::llvm::BitVector &resultIndices) {
+ ::llvm::SmallVector<::mlir::Type> argStorage, resultStorage;
+ ::mlir::TypeRange newArgTypes = filterTypesOut(
$_op.getArgumentTypes(), argIndices, argStorage);
- TypeRange newResultTypes = filterTypesOut(
+ ::mlir::TypeRange newResultTypes = filterTypesOut(
$_op.getResultTypes(), resultIndices, resultStorage);
return $_op.cloneTypeWith(newArgTypes, newResultTypes);
}
- Type getTypeWithoutArgs(const BitVector &argIndices) {
- SmallVector<Type> argStorage;
- TypeRange newArgTypes = filterTypesOut(
+ ::mlir::Type getTypeWithoutArgs(const ::llvm::BitVector &argIndices) {
+ ::llvm::SmallVector<::mlir::Type> argStorage;
+ ::mlir::TypeRange newArgTypes = filterTypesOut(
$_op.getArgumentTypes(), argIndices, argStorage);
return $_op.cloneTypeWith(newArgTypes, $_op.getResultTypes());
}
- Type getTypeWithoutResults(const BitVector &resultIndices) {
- SmallVector<Type> resultStorage;
- TypeRange newResultTypes = filterTypesOut(
+ ::mlir::Type getTypeWithoutResults(const ::llvm::BitVector &resultIndices) {
+ ::llvm::SmallVector<::mlir::Type> resultStorage;
+ ::mlir::TypeRange newResultTypes = filterTypesOut(
$_op.getResultTypes(), resultIndices, resultStorage);
return $_op.cloneTypeWith($_op.getArgumentTypes(), newResultTypes);
}
@@ -369,88 +369,88 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
//===------------------------------------------------------------------===//
/// Return all of the attributes for the argument at 'index'.
- ArrayRef<NamedAttribute> getArgAttrs(unsigned index) {
- return function_interface_impl::getArgAttrs($_op, index);
+ ::llvm::ArrayRef<::mlir::NamedAttribute> getArgAttrs(unsigned index) {
+ return ::mlir::function_interface_impl::getArgAttrs($_op, index);
}
/// Return an ArrayAttr containing all argument attribute dictionaries of
/// this function, or nullptr if no arguments have attributes.
- ArrayAttr getAllArgAttrs() { return $_op.getArgAttrsAttr(); }
+ ::mlir::ArrayAttr getAllArgAttrs() { return $_op.getArgAttrsAttr(); }
/// Return all argument attributes of this function.
- void getAllArgAttrs(SmallVectorImpl<DictionaryAttr> &result) {
- if (ArrayAttr argAttrs = getAllArgAttrs()) {
- auto argAttrRange = argAttrs.template getAsRange<DictionaryAttr>();
+ void getAllArgAttrs(::llvm::SmallVectorImpl<::mlir::DictionaryAttr> &result) {
+ if (::mlir::ArrayAttr argAttrs = getAllArgAttrs()) {
+ auto argAttrRange = argAttrs.template getAsRange<::mlir::DictionaryAttr>();
result.append(argAttrRange.begin(), argAttrRange.end());
} else {
result.append($_op.getNumArguments(),
- DictionaryAttr::get(this->getOperation()->getContext()));
+ ::mlir::DictionaryAttr::get(this->getOperation()->getContext()));
}
}
/// Return the specified attribute, if present, for the argument at 'index',
/// null otherwise.
- Attribute getArgAttr(unsigned index, StringAttr name) {
+ ::mlir::Attribute getArgAttr(unsigned index, ::mlir::StringAttr name) {
auto argDict = getArgAttrDict(index);
return argDict ? argDict.get(name) : nullptr;
}
- Attribute getArgAttr(unsigned index, StringRef name) {
+ ::mlir::Attribute getArgAttr(unsigned index, ::llvm::StringRef name) {
auto argDict = getArgAttrDict(index);
return argDict ? argDict.get(name) : nullptr;
}
template <typename AttrClass>
- AttrClass getArgAttrOfType(unsigned index, StringAttr name) {
+ AttrClass getArgAttrOfType(unsigned index, ::mlir::StringAttr name) {
return ::llvm::dyn_cast_or_null<AttrClass>(getArgAttr(index, name));
}
template <typename AttrClass>
- AttrClass getArgAttrOfType(unsigned index, StringRef name) {
+ AttrClass getArgAttrOfType(unsigned index, ::llvm::StringRef name) {
return ::llvm::dyn_cast_or_null<AttrClass>(getArgAttr(index, name));
}
/// Set the attributes held by the argument at 'index'.
- void setArgAttrs(unsigned index, ArrayRef<NamedAttribute> attributes) {
- function_interface_impl::setArgAttrs($_op, index, attributes);
+ void setArgAttrs(unsigned index, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {
+ ::mlir::function_interface_impl::setArgAttrs($_op, index, attributes);
}
/// Set the attributes held by the argument at 'index'. `attributes` may be
/// null, in which case any existing argument attributes are removed.
- void setArgAttrs(unsigned index, DictionaryAttr attributes) {
- function_interface_impl::setArgAttrs($_op, index, attributes);
+ void setArgAttrs(unsigned index, ::mlir::DictionaryAttr attributes) {
+ ::mlir::function_interface_impl::setArgAttrs($_op, index, attributes);
}
- void setAllArgAttrs(ArrayRef<DictionaryAttr> attributes) {
+ void setAllArgAttrs(::llvm::ArrayRef<::mlir::DictionaryAttr> attributes) {
assert(attributes.size() == $_op.getNumArguments());
- function_interface_impl::setAllArgAttrDicts($_op, attributes);
+ ::mlir::function_interface_impl::setAllArgAttrDicts($_op, attributes);
}
- void setAllArgAttrs(ArrayRef<Attribute> attributes) {
+ void setAllArgAttrs(::llvm::ArrayRef<::mlir::Attribute> attributes) {
assert(attributes.size() == $_op.getNumArguments());
- function_interface_impl::setAllArgAttrDicts($_op, attributes);
+ ::mlir::function_interface_impl::setAllArgAttrDicts($_op, attributes);
}
- void setAllArgAttrs(ArrayAttr attributes) {
+ void setAllArgAttrs(::mlir::ArrayAttr attributes) {
assert(attributes.size() == $_op.getNumArguments());
$_op.setArgAttrsAttr(attributes);
}
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
- void setArgAttr(unsigned index, StringAttr name, Attribute value) {
- function_interface_impl::setArgAttr($_op, index, name, value);
+ void setArgAttr(unsigned index, ::mlir::StringAttr name, ::mlir::Attribute value) {
+ ::mlir::function_interface_impl::setArgAttr($_op, index, name, value);
}
- void setArgAttr(unsigned index, StringRef name, Attribute value) {
+ void setArgAttr(unsigned index, ::llvm::StringRef name, ::mlir::Attribute value) {
setArgAttr(index,
- StringAttr::get(this->getOperation()->getContext(), name),
+ ::mlir::StringAttr::get(this->getOperation()->getContext(), name),
value);
}
/// Remove the attribute 'name' from the argument at 'index'. Return the
/// attribute that was erased, or nullptr if there was no attribute with
/// such name.
- Attribute removeArgAttr(unsigned index, StringAttr name) {
- return function_interface_impl::removeArgAttr($_op, index, name);
+ ::mlir::Attribute removeArgAttr(unsigned index, ::mlir::StringAttr name) {
+ return ::mlir::function_interface_impl::removeArgAttr($_op, index, name);
}
- Attribute removeArgAttr(unsigned index, StringRef name) {
+ ::mlir::Attribute removeArgAttr(unsigned index, ::llvm::StringRef name) {
return removeArgAttr(
- index, StringAttr::get(this->getOperation()->getContext(), name));
+ index, ::mlir::StringAttr::get(this->getOperation()->getContext(), name));
}
//===------------------------------------------------------------------===//
@@ -458,102 +458,102 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [
//===------------------------------------------------------------------===//
/// Return all of the attributes for the result at 'index'.
- ArrayRef<NamedAttribute> getResultAttrs(unsigned index) {
- return function_interface_impl::getResultAttrs($_op, index);
+ ::llvm::ArrayRef<::mlir::NamedAttribute> getResultAttrs(unsigned index) {
+ return ::mlir::function_interface_impl::getResultAttrs($_op, index);
}
/// Return an ArrayAttr containing all result attribute dictionaries of this
/// function, or nullptr if no result have attributes.
- ArrayAttr getAllResultAttrs() { return $_op.getResAttrsAttr(); }
+ ::mlir::ArrayAttr getAllResultAttrs() { return $_op.getResAttrsAttr(); }
/// Return all result attributes of this function.
- void getAllResultAttrs(SmallVectorImpl<DictionaryAttr> &result) {
- if (ArrayAttr argAttrs = getAllResultAttrs()) {
- auto argAttrRange = argAttrs.template getAsRange<DictionaryAttr>();
+ void getAllResultAttrs(::llvm::SmallVectorImpl<::mlir::DictionaryAttr> &result) {
+ if (::mlir::ArrayAttr argAttrs = getAllResultAttrs()) {
+ auto argAttrRange = argAttrs.template getAsRange<::mlir::DictionaryAttr>();
result.append(argAttrRange.begin(), argAttrRange.end());
} else {
result.append($_op.getNumResults(),
- DictionaryAttr::get(this->getOperation()->getContext()));
+ ::mlir::DictionaryAttr::get(this->getOperation()->getContext()));
}
}
/// Return the specified attribute, if present, for the result at 'index',
/// null otherwise.
- Attribute getResultAttr(unsigned index, StringAttr name) {
+ ::mlir::Attribute getResultAttr(unsigned index, ::mlir::StringAttr name) {
auto argDict = getResultAttrDict(index);
return argDict ? argDict.get(name) : nullptr;
}
- Attribute getResultAttr(unsigned index, StringRef name) {
+ ::mlir::Attribute getResultAttr(unsigned index, ::llvm::StringRef name) {
auto argDict = getResultAttrDict(index);
return argDict ? argDict.get(name) : nullptr;
}
template <typename AttrClass>
- AttrClass getResultAttrOfType(unsigned index, StringAttr name) {
+ AttrClass getResultAttrOfType(unsigned index, ::mlir::StringAttr name) {
return ::llvm::dyn_cast_or_null<AttrClass>(getResultAttr(index, name));
}
template <typename AttrClass>
- AttrClass getResultAttrOfType(unsigned index, StringRef name) {
+ AttrClass getResultAttrOfType(unsigned index, ::llvm::StringRef name) {
return ::llvm::dyn_cast_or_null<AttrClass>(getResultAttr(index, name));
}
/// Set the attributes held by the result at 'index'.
- void setResultAttrs(unsigned index, ArrayRef<NamedAttribute> attributes) {
- function_interface_impl::setResultAttrs($_op, index, attributes);
+ void setResultAttrs(unsigned index, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) {
+ ::mlir::function_interface_impl::setResultAttrs($_op, index, attributes);
}
/// Set the attributes held by the result at 'index'. `attributes` may be
/// null, in which case any existing argument attributes are removed.
- void setResultAttrs(unsigned index, DictionaryAttr attributes) {
- function_interface_impl::setResultAttrs($_op, index, attributes);
+ void setResultAttrs(unsigned index, ::mlir::DictionaryAttr attributes) {
+ ::mlir::function_interface_impl::setResultAttrs($_op, index, attributes);
}
- void setAllResultAttrs(ArrayRef<DictionaryAttr> attributes) {
+ void setAllResultAttrs(::llvm::ArrayRef<::mlir::DictionaryAttr> attributes) {
assert(attributes.size() == $_op.getNumResults());
- function_interface_impl::setAllResultAttrDicts(
+ ::mlir::function_interface_impl::setAllResultAttrDicts(
$_op, attributes);
}
- void setAllResultAttrs(ArrayRef<Attribute> attributes) {
+ void setAllResultAttrs(::llvm::ArrayRef<::mlir::Attribute> attributes) {
assert(attributes.size() == $_op.getNumResults());
- function_interface_impl::setAllResultAttrDicts(
+ ::mlir::function_interface_impl::setAllResultAttrDicts(
$_op, attributes);
}
- void setAllResultAttrs(ArrayAttr attributes) {
+ void setAllResultAttrs(::mlir::ArrayAttr attributes) {
assert(attributes.size() == $_op.getNumResults());
$_op.setResAttrsAttr(attributes);
}
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
- void setResultAttr(unsigned index, StringAttr name, Attribute value) {
- function_interface_impl::setResultAttr($_op, index, name, value);
+ void setResultAttr(unsigned index, ::mlir::StringAttr name, ::mlir::Attribute value) {
+ ::mlir::function_interface_impl::setResultAttr($_op, index, name, value);
}
- void setResultAttr(unsigned index, StringRef name, Attribute value) {
+ void setResultAttr(unsigned index, ::llvm::StringRef name, ::mlir::Attribute value) {
setResultAttr(index,
- StringAttr::get(this->getOperation()->getContext(), name),
+ ::mlir::StringAttr::get(this->getOperation()->getContext(), name),
value);
}
/// Remove the attribute 'name' from the result at 'index'. Return the
/// attribute that was erased, or nullptr if there was no attribute with
/// such name.
- Attribute removeResultAttr(unsigned index, StringAttr name) {
- return function_interface_impl::removeResultAttr($_op, index, name);
+ ::mlir::Attribute removeResultAttr(unsigned index, ::mlir::StringAttr name) {
+ return ::mlir::function_interface_impl::removeResultAttr($_op, index, name);
}
/// Returns the dictionary attribute corresponding to the argument at
/// 'index'. If there are no argument attributes at 'index', a null
/// attribute is returned.
- DictionaryAttr getArgAttrDict(unsigned index) {
+ ::mlir::DictionaryAttr getArgAttrDict(unsigned index) {
assert(index < $_op.getNumArguments() && "invalid argument number");
- return function_interface_impl::getArgAttrDict($_op, index);
+ return ::mlir::function_interface_impl::getArgAttrDict($_op, index);
}
/// Returns the dictionary attribute corresponding to the result at 'index'.
/// If there are no result attributes at 'index', a null attribute is
/// returned.
- DictionaryAttr getResultAttrDict(unsigned index) {
+ ::mlir::DictionaryAttr getResultAttrDict(unsigned index) {
assert(index < $_op.getNumResults() && "invalid result number");
- return function_interface_impl::getResultAttrDict($_op, index);
+ return ::mlir::function_interface_impl::getResultAttrDict($_op, index);
}
}];
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 0d7722a..7e8e67a 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -24,9 +24,11 @@ namespace mlir {
// Forward declarations.
class Attribute;
class Block;
+struct ConversionConfig;
class ConversionPatternRewriter;
class MLIRContext;
class Operation;
+struct OperationConverter;
class Type;
class Value;
@@ -657,12 +659,13 @@ struct ConversionPatternRewriterImpl;
/// hooks.
class ConversionPatternRewriter final : public PatternRewriter {
public:
- explicit ConversionPatternRewriter(MLIRContext *ctx);
~ConversionPatternRewriter() override;
/// Apply a signature conversion to the entry block of the given region. This
/// replaces the entry block with a new block containing the updated
/// signature. The new entry block to the region is returned for convenience.
+ /// If no block argument types are changing, the original entry block will be
+ /// left in place and returned.
///
/// If provided, `converter` will be used for any materializations.
Block *
@@ -671,8 +674,11 @@ public:
const TypeConverter *converter = nullptr);
/// Convert the types of block arguments within the given region. This
- /// replaces each block with a new block containing the updated signature. The
- /// entry block may have a special conversion if `entryConversion` is
+ /// replaces each block with a new block containing the updated signature. If
+ /// an updated signature would match the current signature, the respective
+ /// block is left in place as is.
+ ///
+ /// The entry block may have a special conversion if `entryConversion` is
/// provided. On success, the new entry block to the region is returned for
/// convenience. Otherwise, failure is returned.
FailureOr<Block *> convertRegionTypes(
@@ -681,7 +687,8 @@ public:
/// Convert the types of block arguments within the given region except for
/// the entry region. This replaces each non-entry block with a new block
- /// containing the updated signature.
+ /// containing the updated signature. If an updated signature would match the
+ /// current signature, the respective block is left in place as is.
///
/// If special conversion behavior is needed for the non-entry blocks (for
/// example, we need to convert only a subset of a BB arguments), such
@@ -758,6 +765,15 @@ public:
detail::ConversionPatternRewriterImpl &getImpl();
private:
+ // Allow OperationConverter to construct new rewriters.
+ friend struct OperationConverter;
+
+ /// Conversion pattern rewriters must not be used outside of dialect
+ /// conversions. They apply some IR rewrites in a delayed fashion and could
+ /// bring the IR into an inconsistent state when used standalone.
+ explicit ConversionPatternRewriter(MLIRContext *ctx,
+ const ConversionConfig &config);
+
// Hide unsupported pattern rewriter API.
using OpBuilder::setListener;
@@ -1057,6 +1073,30 @@ public:
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
//===----------------------------------------------------------------------===//
+// ConversionConfig
+//===----------------------------------------------------------------------===//
+
+/// Dialect conversion configuration.
+struct ConversionConfig {
+ /// An optional callback used to notify about match failure diagnostics during
+ /// the conversion. Diagnostics reported to this callback may only be
+ /// available in debug mode.
+ function_ref<void(Diagnostic &)> notifyCallback = nullptr;
+
+ /// Partial conversion only. All operations that are found not to be
+ /// legalizable are placed in this set. (Note that if there is an op
+ /// explicitly marked as illegal, the conversion terminates and the set will
+ /// not necessarily be complete.)
+ DenseSet<Operation *> *unlegalizedOps = nullptr;
+
+ /// Analysis conversion only. All operations that are found to be legalizable
+ /// are placed in this set. Note that no actual rewrites are applied to the
+ /// IR during an analysis conversion and only pre-existing operations are
+ /// added to the set.
+ DenseSet<Operation *> *legalizableOps = nullptr;
+};
+
+//===----------------------------------------------------------------------===//
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
@@ -1069,19 +1109,16 @@ public:
/// Apply a partial conversion on the given operations and all nested
/// operations. This method converts as many operations to the target as
/// possible, ignoring operations that failed to legalize. This method only
-/// returns failure if there ops explicitly marked as illegal. If an
-/// `unconvertedOps` set is provided, all operations that are found not to be
-/// legalizable to the given `target` are placed within that set. (Note that if
-/// there is an op explicitly marked as illegal, the conversion terminates and
-/// the `unconvertedOps` set will not necessarily be complete.)
+/// returns failure if there are ops explicitly marked as illegal.
LogicalResult
-applyPartialConversion(ArrayRef<Operation *> ops, const ConversionTarget &target,
+applyPartialConversion(ArrayRef<Operation *> ops,
+ const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> *unconvertedOps = nullptr);
+ ConversionConfig config = ConversionConfig());
LogicalResult
applyPartialConversion(Operation *op, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> *unconvertedOps = nullptr);
+ ConversionConfig config = ConversionConfig());
/// Apply a complete conversion on the given operations, and all nested
/// operations. This method returns failure if the conversion of any operation
@@ -1089,31 +1126,27 @@ applyPartialConversion(Operation *op, const ConversionTarget &target,
/// within 'ops'.
LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns);
+ const FrozenRewritePatternSet &patterns,
+ ConversionConfig config = ConversionConfig());
LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns);
+ const FrozenRewritePatternSet &patterns,
+ ConversionConfig config = ConversionConfig());
/// Apply an analysis conversion on the given operations, and all nested
/// operations. This method analyzes which operations would be successfully
/// converted to the target if a conversion was applied. All operations that
/// were found to be legalizable to the given 'target' are placed within the
-/// provided 'convertedOps' set; note that no actual rewrites are applied to the
-/// operations on success and only pre-existing operations are added to the set.
-/// This method only returns failure if there are unreachable blocks in any of
-/// the regions nested within 'ops'. There's an additional argument
-/// `notifyCallback` which is used for collecting match failure diagnostics
-/// generated during the conversion. Diagnostics are only reported to this
-/// callback may only be available in debug mode.
-LogicalResult applyAnalysisConversion(
- ArrayRef<Operation *> ops, ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> &convertedOps,
- function_ref<void(Diagnostic &)> notifyCallback = nullptr);
-LogicalResult applyAnalysisConversion(
- Operation *op, ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> &convertedOps,
- function_ref<void(Diagnostic &)> notifyCallback = nullptr);
+/// provided 'config.legalizableOps' set; note that no actual rewrites are
+/// applied to the operations on success. This method only returns failure if
+/// there are unreachable blocks in any of the regions nested within 'ops'.
+LogicalResult
+applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns,
+ ConversionConfig config = ConversionConfig());
+LogicalResult
+applyAnalysisConversion(Operation *op, ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns,
+ ConversionConfig config = ConversionConfig());
} // namespace mlir
#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 1b72961..23e9572 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -148,10 +148,10 @@ struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
+ auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
auto splatAttr = SplatElementsAttr::get(
- mlir::VectorType::get(
- {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
- floatType),
+ mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
+ {numElements.isScalable()}),
floatOne);
auto one =
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
@@ -207,10 +207,10 @@ struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
+ auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
auto splatAttr = SplatElementsAttr::get(
- mlir::VectorType::get(
- {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
- floatType),
+ mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
+ {numElements.isScalable()}),
floatOne);
auto one =
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
@@ -266,10 +266,10 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) {
+ auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
auto splatAttr = SplatElementsAttr::get(
- mlir::VectorType::get(
- {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
- floatType),
+ mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
+ {numElements.isScalable()}),
floatOne);
auto one =
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7eb32eb..7c477f2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -384,23 +384,23 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
auto intTy = cast<IntegerType>(elementTy);
- int32_t min = static_cast<int32_t>(
- cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue());
- int32_t max = static_cast<int32_t>(
- cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue());
+ int64_t min =
+ cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
+ int64_t max =
+ cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
if (intTy.isUnsignedInteger()) {
- min = std::max<int32_t>(min, 0);
- max = std::min<int32_t>(
+ min = std::max(min, (int64_t)0);
+ max = std::min(
max,
APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
} else {
- min = std::max<int32_t>(
- min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
- .getSExtValue());
- max = std::min<int32_t>(
- max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
- .getSExtValue());
+ min =
+ std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
+ .getSExtValue());
+ max =
+ std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
+ .getSExtValue());
}
auto minVal = rewriter.create<arith::ConstantIntOp>(
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index e88f82c..26dfb38 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -40,12 +40,12 @@ namespace {
//===----------------------------------------------------------------------===//
// Common match failure reasons.
-static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE(
+static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
"op vector size is not multiple of SME tiles");
-static constexpr StringLiteral MATCH_FAILURE_UNSUPPORTED_MASK_OP(
+static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
"op mask is unsupported for legalization/decomposition");
static constexpr StringLiteral
- MATCH_FAILURE_NON_PERMUTATION_MAP("op affine map is not a permutation");
+ kMatchFailureNonPermutationMap("op affine map is not a permutation");
/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
/// larger vector type. The (`row`, `col`) are the position of the tile in the
@@ -174,8 +174,8 @@ struct LegalizeVectorOuterProductOpsByDecomposition
OneToNPatternRewriter &rewriter) const override {
auto vectorType = outerProductOp.getResultVectorType();
if (!isMultipleOfSMETileVectorType(vectorType))
- return rewriter.notifyMatchFailure(
- outerProductOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE);
+ return rewriter.notifyMatchFailure(outerProductOp,
+ kMatchFailureNotSMETileTypeMultiple);
Value mask;
Operation *rootOp = outerProductOp;
@@ -188,7 +188,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
if (!isSupportedMaskOp(mask))
return rewriter.notifyMatchFailure(outerProductOp,
- MATCH_FAILURE_UNSUPPORTED_MASK_OP);
+ kMatchFailureUnsupportedMaskOp);
ValueRange accSMETiles = adaptor.getAcc();
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
@@ -252,18 +252,18 @@ struct LegalizeTransferReadOpsByDecomposition
OneToNPatternRewriter &rewriter) const override {
auto vectorType = readOp.getVectorType();
if (!isMultipleOfSMETileVectorType(vectorType))
- return rewriter.notifyMatchFailure(
- readOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE);
+ return rewriter.notifyMatchFailure(readOp,
+ kMatchFailureNotSMETileTypeMultiple);
auto mask = readOp.getMask();
if (!isSupportedMaskOp(mask))
return rewriter.notifyMatchFailure(readOp,
- MATCH_FAILURE_UNSUPPORTED_MASK_OP);
+ kMatchFailureUnsupportedMaskOp);
auto permutationMap = readOp.getPermutationMap();
if (!permutationMap.isPermutation())
return rewriter.notifyMatchFailure(readOp,
- MATCH_FAILURE_NON_PERMUTATION_MAP);
+ kMatchFailureNonPermutationMap);
// Note: For 2D vector types the only non-identity permutation is a simple
// tranpose [1, 0].
@@ -300,18 +300,18 @@ struct LegalizeTransferWriteOpsByDecomposition
OneToNPatternRewriter &rewriter) const override {
auto vectorType = writeOp.getVectorType();
if (!isMultipleOfSMETileVectorType(vectorType))
- return rewriter.notifyMatchFailure(
- writeOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE);
+ return rewriter.notifyMatchFailure(writeOp,
+ kMatchFailureNotSMETileTypeMultiple);
auto mask = writeOp.getMask();
if (!isSupportedMaskOp(mask))
return rewriter.notifyMatchFailure(writeOp,
- MATCH_FAILURE_UNSUPPORTED_MASK_OP);
+ kMatchFailureUnsupportedMaskOp);
auto permutationMap = writeOp.getPermutationMap();
if (!permutationMap.isPermutation())
return rewriter.notifyMatchFailure(writeOp,
- MATCH_FAILURE_NON_PERMUTATION_MAP);
+ kMatchFailureNonPermutationMap);
// Note: For 2D vector types the only non-identity permutation is a simple
// tranpose [1, 0].
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index e5776e1..51cfa22 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -1,11 +1,3 @@
-if ("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD)
- set(NVPTX_LIBS
- NVPTXCodeGen
- NVPTXDesc
- NVPTXInfo
- )
-endif()
-
if (MLIR_ENABLE_ROCM_CONVERSIONS)
set(AMDGPU_LIBS
IRReader
@@ -60,7 +52,6 @@ add_mlir_dialect_library(MLIRGPUTransforms
Transforms/ParallelLoopMapper.cpp
Transforms/ROCDLAttachTarget.cpp
Transforms/SerializeToBlob.cpp
- Transforms/SerializeToCubin.cpp
Transforms/SerializeToHsaco.cpp
Transforms/ShuffleRewriter.cpp
Transforms/SPIRVAttachTarget.cpp
@@ -74,7 +65,6 @@ add_mlir_dialect_library(MLIRGPUTransforms
Core
MC
Target
- ${NVPTX_LIBS}
${AMDGPU_LIBS}
DEPENDS
@@ -110,48 +100,6 @@ add_mlir_dialect_library(MLIRGPUTransforms
add_subdirectory(TransformOps)
add_subdirectory(Pipelines)
-if(MLIR_ENABLE_CUDA_RUNNER)
- if(NOT MLIR_ENABLE_CUDA_CONVERSIONS)
- message(SEND_ERROR
- "Building mlir with cuda support requires the NVPTX backend")
- endif()
-
- # Configure CUDA language support. Using check_language first allows us to
- # give a custom error message.
- include(CheckLanguage)
- check_language(CUDA)
- if (CMAKE_CUDA_COMPILER)
- enable_language(CUDA)
- else()
- message(SEND_ERROR
- "Building mlir with cuda support requires a working CUDA install")
- endif()
-
- # Enable gpu-to-cubin pass.
- target_compile_definitions(obj.MLIRGPUTransforms
- PRIVATE
- MLIR_GPU_TO_CUBIN_PASS_ENABLE=1
- )
-
- # Add CUDA headers includes and the libcuda.so library.
- target_include_directories(obj.MLIRGPUTransforms
- PRIVATE
- ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
- )
-
- # Add link path for the cuda driver library.
- find_library(CUDA_DRIVER_LIBRARY cuda HINTS ${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES} REQUIRED)
- get_filename_component(CUDA_DRIVER_LIBRARY_PATH "${CUDA_DRIVER_LIBRARY}" DIRECTORY)
- target_link_directories(MLIRGPUTransforms PRIVATE ${CUDA_DRIVER_LIBRARY_PATH})
-
- target_link_libraries(MLIRGPUTransforms
- PRIVATE
- MLIRNVVMToLLVMIRTranslation
- cuda
- )
-
-endif()
-
if(MLIR_ENABLE_ROCM_CONVERSIONS)
if (NOT ("AMDGPU" IN_LIST LLVM_TARGETS_TO_BUILD))
message(SEND_ERROR
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 30b6cd7..33ce5c1 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -648,6 +648,8 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
TypeRange workgroupAttributions,
TypeRange privateAttributions, Value clusterSizeX,
Value clusterSizeY, Value clusterSizeZ) {
+ OpBuilder::InsertionGuard g(builder);
+
// Add a WorkGroup attribution attribute. This attribute is required to
// identify private attributions in the list of block argguments.
result.addAttribute(getNumWorkgroupAttributionsAttrName(),
@@ -674,7 +676,7 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
// attributions, where the first kNumConfigRegionAttributes arguments have
// `index` type and the rest have the same types as the data operands.
Region *kernelRegion = result.addRegion();
- Block *body = new Block();
+ Block *body = builder.createBlock(kernelRegion);
// TODO: Allow passing in proper locations here.
for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
body->addArgument(builder.getIndexType(), result.location);
@@ -683,7 +685,6 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
body->addArgument(argTy, result.location);
for (Type argTy : privateAttributions)
body->addArgument(argTy, result.location);
- kernelRegion->push_back(body);
// Fill OperandSegmentSize Attribute.
SmallVector<int32_t, 11> segmentSizes(11, 1);
segmentSizes.front() = asyncDependencies.size();
@@ -1325,6 +1326,8 @@ void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
TypeRange workgroupAttributions,
TypeRange privateAttributions,
ArrayRef<NamedAttribute> attrs) {
+ OpBuilder::InsertionGuard g(builder);
+
result.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(name));
result.addAttribute(getFunctionTypeAttrName(result.name),
@@ -1333,7 +1336,7 @@ void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
builder.getI64IntegerAttr(workgroupAttributions.size()));
result.addAttributes(attrs);
Region *body = result.addRegion();
- Block *entryBlock = new Block;
+ Block *entryBlock = builder.createBlock(body);
// TODO: Allow passing in proper locations here.
for (Type argTy : type.getInputs())
@@ -1342,8 +1345,6 @@ void GPUFuncOp::build(OpBuilder &builder, OperationState &result,
entryBlock->addArgument(argTy, result.location);
for (Type argTy : privateAttributions)
entryBlock->addArgument(argTy, result.location);
-
- body->getBlocks().push_back(entryBlock);
}
/// Parses a GPU function memory attribution.
diff --git a/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp b/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp
deleted file mode 100644
index 34ad4e6..0000000
--- a/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp
+++ /dev/null
@@ -1,180 +0,0 @@
-//===- LowerGPUToCUBIN.cpp - Convert GPU kernel to CUBIN blob -------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements a pass that serializes a gpu module into CUBIN blob and
-// adds that blob as a string attribute of the module.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/GPU/Transforms/Passes.h"
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
-#include "llvm/Support/Debug.h"
-
-#if MLIR_GPU_TO_CUBIN_PASS_ENABLE
-#include "mlir/Pass/Pass.h"
-#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
-#include "mlir/Target/LLVMIR/Export.h"
-#include "llvm/Support/TargetSelect.h"
-#include "llvm/Support/Threading.h"
-
-#include <cuda.h>
-
-using namespace mlir;
-
-static void emitCudaError(const llvm::Twine &expr, const char *buffer,
- CUresult result, Location loc) {
- const char *error = nullptr;
- cuGetErrorString(result, &error);
- emitError(loc,
- expr.concat(error ? " failed with error code " + llvm::Twine{error}
- : llvm::Twine(" failed with unknown error "))
- .concat("[")
- .concat(buffer)
- .concat("]"));
-}
-
-#define RETURN_ON_CUDA_ERROR(expr) \
- do { \
- if (auto status = (expr)) { \
- emitCudaError(#expr, jitErrorBuffer, status, loc); \
- return {}; \
- } \
- } while (false)
-
-namespace {
-class SerializeToCubinPass
- : public PassWrapper<SerializeToCubinPass, gpu::SerializeToBlobPass> {
- static llvm::once_flag initializeBackendOnce;
-
-public:
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SerializeToCubinPass)
-
- SerializeToCubinPass(StringRef triple = "nvptx64-nvidia-cuda",
- StringRef chip = "sm_35", StringRef features = "+ptx60",
- int optLevel = 2, bool dumpPtx = false);
-
- StringRef getArgument() const override { return "gpu-to-cubin"; }
- StringRef getDescription() const override {
- return "Lower GPU kernel function to CUBIN binary annotations";
- }
-
-private:
- // Serializes PTX to CUBIN.
- std::unique_ptr<std::vector<char>>
- serializeISA(const std::string &isa) override;
-};
-} // namespace
-
-// Sets the 'option' to 'value' unless it already has a value.
-static void maybeSetOption(Pass::Option<std::string> &option, StringRef value) {
- if (!option.hasValue())
- option = value.str();
-}
-
-llvm::once_flag SerializeToCubinPass::initializeBackendOnce;
-
-SerializeToCubinPass::SerializeToCubinPass(StringRef triple, StringRef chip,
- StringRef features, int optLevel,
- bool dumpPtx) {
- // No matter how this pass is constructed, ensure that the NVPTX backend
- // is initialized exactly once.
- llvm::call_once(initializeBackendOnce, []() {
- // Initialize LLVM NVPTX backend.
-#if LLVM_HAS_NVPTX_TARGET
- LLVMInitializeNVPTXTarget();
- LLVMInitializeNVPTXTargetInfo();
- LLVMInitializeNVPTXTargetMC();
- LLVMInitializeNVPTXAsmPrinter();
-#endif
- });
-
- maybeSetOption(this->triple, triple);
- maybeSetOption(this->chip, chip);
- maybeSetOption(this->features, features);
- this->dumpPtx = dumpPtx;
- if (this->optLevel.getNumOccurrences() == 0)
- this->optLevel.setValue(optLevel);
-}
-
-std::unique_ptr<std::vector<char>>
-SerializeToCubinPass::serializeISA(const std::string &isa) {
- Location loc = getOperation().getLoc();
- char jitErrorBuffer[4096] = {0};
-
- RETURN_ON_CUDA_ERROR(cuInit(0));
-
- // Linking requires a device context.
- CUdevice device;
- RETURN_ON_CUDA_ERROR(cuDeviceGet(&device, 0));
- CUcontext context;
- // Use the primary context.
- RETURN_ON_CUDA_ERROR(cuDevicePrimaryCtxRetain(&context, device));
- // Push the primary context so that the next CUDA operations
- // actually use it.
- RETURN_ON_CUDA_ERROR(cuCtxPushCurrent(context));
- CUlinkState linkState;
-
- CUjit_option jitOptions[] = {CU_JIT_ERROR_LOG_BUFFER,
- CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
- void *jitOptionsVals[] = {jitErrorBuffer,
- reinterpret_cast<void *>(sizeof(jitErrorBuffer))};
-
- RETURN_ON_CUDA_ERROR(cuLinkCreate(2, /* number of jit options */
- jitOptions, /* jit options */
- jitOptionsVals, /* jit option values */
- &linkState));
-
- auto kernelName = getOperation().getName().str();
- if (dumpPtx) {
- llvm::dbgs() << " Kernel Name : [" << kernelName << "]\n";
- llvm::dbgs() << isa << "\n";
- }
- RETURN_ON_CUDA_ERROR(cuLinkAddData(
- linkState, CUjitInputType::CU_JIT_INPUT_PTX,
- const_cast<void *>(static_cast<const void *>(isa.c_str())), isa.length(),
- kernelName.c_str(), 0, /* number of jit options */
- nullptr, /* jit options */
- nullptr /* jit option values */
- ));
-
- void *cubinData;
- size_t cubinSize;
- RETURN_ON_CUDA_ERROR(cuLinkComplete(linkState, &cubinData, &cubinSize));
-
- char *cubinAsChar = static_cast<char *>(cubinData);
- auto result =
- std::make_unique<std::vector<char>>(cubinAsChar, cubinAsChar + cubinSize);
-
- // This will also destroy the cubin data.
- RETURN_ON_CUDA_ERROR(cuLinkDestroy(linkState));
- // Pop and release the primary context.
- CUcontext poppedContext;
- RETURN_ON_CUDA_ERROR(cuCtxPopCurrent(&poppedContext));
- RETURN_ON_CUDA_ERROR(cuDevicePrimaryCtxRelease(device));
-
- return result;
-}
-
-// Register pass to serialize GPU kernel functions to a CUBIN binary annotation.
-void mlir::registerGpuSerializeToCubinPass() {
- PassRegistration<SerializeToCubinPass> registerSerializeToCubin(
- [] { return std::make_unique<SerializeToCubinPass>(); });
-}
-
-std::unique_ptr<Pass> mlir::createGpuSerializeToCubinPass(StringRef triple,
- StringRef arch,
- StringRef features,
- int optLevel,
- bool dumpPtx) {
- return std::make_unique<SerializeToCubinPass>(triple, arch, features,
- optLevel, dumpPtx);
-}
-
-#else // MLIR_GPU_TO_CUBIN_PASS_ENABLE
-void mlir::registerGpuSerializeToCubinPass() {}
-#endif // MLIR_GPU_TO_CUBIN_PASS_ENABLE
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 7eed792..3627ff6 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1041,6 +1041,11 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
LinalgOp linalgOp = cast<LinalgOp>(op);
+ // Mixed tensor/buffer operands are not allowed.
+ if (!linalgOp.hasPureTensorSemantics() &&
+ !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
+ return op->emitOpError("expected to have pure tensor or buffer semantics");
+
// Before checking indexing maps, we need to make sure the attributes
// referenced by it are valid.
if (linalgOp.hasDynamicIndexingMaps())
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 71e4e13..962cb28 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1471,6 +1471,16 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
//----------------------------------------------------------------------------//
+void mlir::populatePolynomialApproximateTanhPattern(
+ RewritePatternSet &patterns) {
+ patterns.add<TanhApproximation>(patterns.getContext());
+}
+
+void mlir::populatePolynomialApproximateErfPattern(
+ RewritePatternSet &patterns) {
+ patterns.add<ErfPolynomialApproximation>(patterns.getContext());
+}
+
void mlir::populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns,
const MathPolynomialApproximationOptions &options) {
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
index 7cbe0de..c4d8b0b 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -593,7 +593,6 @@ static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) {
Operation *definingOp = operand.getDefiningOp();
assert(definingOp);
ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
- assert(shardOp.getAnnotateForUsers());
return shardOp.getShard();
});
return res;
@@ -615,34 +614,46 @@ static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) {
assert(result.hasOneUse());
Operation *userOp = *result.getUsers().begin();
ShardOp shardOp = llvm::cast<ShardOp>(userOp);
- assert(!shardOp.getAnnotateForUsers());
return shardOp.getShard();
});
return res;
}
static LogicalResult
-spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
+spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection,
OpBuilder &builder) {
- ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
- if (shardOp) {
- if (!shardOp.getAnnotateForUsers()) {
- return success();
- }
-
+ Value targetSpmdValue;
+
+ // Check if 2 shard ops are chained. If not there is no need for resharding
+ // as the source and target shared the same sharding.
+ ShardOp srcShardOp =
+ dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp());
+ if (!srcShardOp) {
+ targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand());
+ } else {
// Insert resharding.
- ShardOp srcShardOp =
- llvm::cast<ShardOp>(shardOp.getOperand().getDefiningOp());
- assert(!srcShardOp.getAnnotateForUsers());
+ assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
TypedValue<ShapedType> srcSpmdValue =
spmdizationMap.lookup(srcShardOp.getOperand())
.cast<TypedValue<ShapedType>>();
- Value targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
- symbolTableCollection);
- assert(!spmdizationMap.contains(shardOp.getResult()));
- spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
- return success();
+ targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
+ symbolTableCollection);
+ }
+
+ assert(!spmdizationMap.contains(shardOp.getResult()));
+ spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
+ return success();
+}
+
+static LogicalResult
+spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTableCollection,
+ OpBuilder &builder) {
+ ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
+ if (shardOp) {
+ return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
+ builder);
}
SmallVector<Value> spmdizedOperands;
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 119df9a..233e702 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -306,17 +306,18 @@ void ConditionOp::getSuccessorRegions(
void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
Value ub, Value step, ValueRange iterArgs,
BodyBuilderFn bodyBuilder) {
+ OpBuilder::InsertionGuard guard(builder);
+
result.addOperands({lb, ub, step});
result.addOperands(iterArgs);
for (Value v : iterArgs)
result.addTypes(v.getType());
Type t = lb.getType();
Region *bodyRegion = result.addRegion();
- bodyRegion->push_back(new Block);
- Block &bodyBlock = bodyRegion->front();
- bodyBlock.addArgument(t, result.location);
+ Block *bodyBlock = builder.createBlock(bodyRegion);
+ bodyBlock->addArgument(t, result.location);
for (Value v : iterArgs)
- bodyBlock.addArgument(v.getType(), v.getLoc());
+ bodyBlock->addArgument(v.getType(), v.getLoc());
// Create the default terminator if the builder is not provided and if the
// iteration arguments are not provided. Otherwise, leave this to the caller
@@ -325,9 +326,9 @@ void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
ForOp::ensureTerminator(*bodyRegion, builder, result.location);
} else if (bodyBuilder) {
OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPointToStart(&bodyBlock);
- bodyBuilder(builder, result.location, bodyBlock.getArgument(0),
- bodyBlock.getArguments().drop_front());
+ builder.setInsertionPointToStart(bodyBlock);
+ bodyBuilder(builder, result.location, bodyBlock->getArgument(0),
+ bodyBlock->getArguments().drop_front());
}
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 950ee59..62d0785 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1330,6 +1330,7 @@ NARY_SHAPE_INFER(tosa::CastOp)
NARY_SHAPE_INFER(tosa::CeilOp)
NARY_SHAPE_INFER(tosa::ClampOp)
NARY_SHAPE_INFER(tosa::ClzOp)
+NARY_SHAPE_INFER(tosa::CosOp)
NARY_SHAPE_INFER(tosa::DivOp)
NARY_SHAPE_INFER(tosa::ExpOp)
NARY_SHAPE_INFER(tosa::FloorOp)
@@ -1352,6 +1353,7 @@ NARY_SHAPE_INFER(tosa::ReciprocalOp)
NARY_SHAPE_INFER(tosa::RescaleOp)
NARY_SHAPE_INFER(tosa::ReverseOp)
NARY_SHAPE_INFER(tosa::RsqrtOp)
+NARY_SHAPE_INFER(tosa::SinOp)
NARY_SHAPE_INFER(tosa::SelectOp)
NARY_SHAPE_INFER(tosa::SubOp)
NARY_SHAPE_INFER(tosa::TanhOp)
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index baaa581..4c96065 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -7,13 +7,12 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Utils/IndexingUtils.h"
-
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/STLExtras.h"
-
#include <numeric>
#include <optional>
@@ -306,6 +305,14 @@ mlir::computeLinearIndex(OpFoldResult sourceOffset,
return {expr, values};
}
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+mlir::computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<int64_t> strides,
+ ArrayRef<Value> indices) {
+ return computeLinearIndex(
+ sourceOffset, getAsIndexOpFoldResult(sourceOffset.getContext(), strides),
+ getAsOpFoldResult(ValueRange(indices)));
+}
+
//===----------------------------------------------------------------------===//
// TileOffsetRange
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 36fb667..fc11ae6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -724,9 +724,8 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
VectorType preconditionType,
Operation *op) {
- if (!preconditionType || preconditionType.getRank() != 1 ||
- preconditionType.isScalable())
- return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
+ if (!preconditionType || preconditionType.isScalable())
+ return rewriter.notifyMatchFailure(op, "scalable vector");
// TODO: consider relaxing this restriction in the future if we find ways
// to really work with subbyte elements across the MLIR/LLVM boundary.
@@ -743,6 +742,9 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
return rewriter.notifyMatchFailure(op, "types are not vector");
+ if (!preconditionType || preconditionType.getRank() != 1)
+ return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
+
return commonConversionPrecondition(rewriter, preconditionType, op);
}
@@ -855,7 +857,6 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
"Expected i4 type");
// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
- int64_t vecDimSize = srcVecType.getShape().back();
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
constexpr int64_t i4Toi8BitwidthFactor = 2;
i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
@@ -871,16 +872,8 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
- // 3. Interleave low and high i8 elements using a shuffle.
- SmallVector<int64_t> interleaveMaskValues;
- interleaveMaskValues.reserve(vecDimSize);
- for (int i = 0, end = vecDimSize / 2; i < end; ++i) {
- interleaveMaskValues.push_back(i);
- interleaveMaskValues.push_back(i + (vecDimSize / 2));
- }
-
- return rewriter.create<vector::ShuffleOp>(
- loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
+ // 3. Interleave low and high i8 elements.
+ return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
namespace {
@@ -1008,8 +1001,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
/// %1 = arith.shli %0, 4 : vector<4xi8>
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
-/// %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
-/// : vector<4xi8>, vector<4xi8>
+/// %4 = vector.interleave %2, %3 : vector<4xi8>
/// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
///
/// arith.sitofp %in : vector<8xi4> to vector<8xf32>
@@ -1018,8 +1010,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
/// %1 = arith.shli %0, 4 : vector<4xi8>
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
-/// %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
-/// : vector<4xi8>, vector<4xi8>
+/// %4 = vector.interleave %2, %3 : vector<4xi8>
/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
///
template <typename ConversionOpType>
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 04e5a81..0ffef6a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
@@ -577,7 +578,6 @@ public:
if (transferReadOp.getMask())
return failure();
- SmallVector<Value> collapsedIndices;
int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
// 1. Collapse the source memref
@@ -599,12 +599,14 @@ public:
// 2.2 New indices
// If all the collapsed indices are zero then no extra logic is needed.
// Otherwise, a new offset/index has to be computed.
+ SmallVector<Value> collapsedIndices;
if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
firstDimToCollapse,
collapsedIndices))) {
- // Copy all the leading indices
- collapsedIndices = transferReadOp.getIndices();
- collapsedIndices.resize(firstDimToCollapse);
+ // Copy all the leading indices.
+ SmallVector<Value> indices = transferReadOp.getIndices();
+ collapsedIndices.append(indices.begin(),
+ indices.begin() + firstDimToCollapse);
// Compute the remaining trailing index/offset required for reading from
// the collapsed memref:
@@ -621,24 +623,26 @@ public:
// memref<1x86xi32>, vector<2xi32>
// one would get the following offset:
// %offset = %arg0 * 43
- AffineExpr offsetExpr, idxExpr;
- bindSymbols(rewriter.getContext(), offsetExpr, idxExpr);
-
- int64_t outputRank = transferReadOp.getIndices().size();
- OpFoldResult offset =
+ OpFoldResult collapsedOffset =
rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
- for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
- int64_t dim = dyn_cast<ShapedType>(source.getType()).getDimSize(i);
- offset = affine::makeComposedFoldedAffineApply(
- rewriter, loc, offsetExpr + dim * idxExpr,
- {offset, transferReadOp.getIndices()[i]});
- }
- if (offset.is<Value>()) {
- collapsedIndices.push_back(offset.get<Value>());
+ auto sourceShape = sourceType.getShape();
+ auto collapsedStrides = computeSuffixProduct(ArrayRef<int64_t>(
+ sourceShape.begin() + firstDimToCollapse, sourceShape.end()));
+
+ // Compute the collapsed offset.
+ ArrayRef<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
+ indices.end());
+ auto &&[collapsedExpr, collapsedVals] = computeLinearIndex(
+ collapsedOffset, collapsedStrides, indicesToCollapse);
+ collapsedOffset = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, collapsedExpr, collapsedVals);
+
+ if (collapsedOffset.is<Value>()) {
+ collapsedIndices.push_back(collapsedOffset.get<Value>());
} else {
collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
- loc, *getConstantIntValue(offset)));
+ loc, *getConstantIntValue(collapsedOffset)));
}
}
@@ -710,6 +714,7 @@ public:
firstContiguousInnerDim,
collapsedIndices)))
return failure();
+
Value collapsedSource =
collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
MemRefType collapsedSourceType =
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
index 9e8b240..bbe10b0 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
@@ -74,13 +74,6 @@ MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATES)
MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETVALUES)
#undef IMPL_GETVALUES
-#define IMPL_FORWARDINGINSERT(VNAME, V) \
- void SparseTensorStorageBase::forwardingInsert(const uint64_t *, V) { \
- FATAL_PIV("forwardingInsert" #VNAME); \
- }
-MLIR_SPARSETENSOR_FOREVERY_V(IMPL_FORWARDINGINSERT)
-#undef IMPL_FORWARDINGINSERT
-
#define IMPL_LEXINSERT(VNAME, V) \
void SparseTensorStorageBase::lexInsert(const uint64_t *, V) { \
FATAL_PIV("lexInsert" #VNAME); \
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index a5e75a7..0bc90b2 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -117,20 +117,7 @@ extern "C" {
switch (action) { \
case Action::kEmpty: { \
return SparseTensorStorage<P, C, V>::newEmpty( \
- dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \
- false); \
- } \
- case Action::kEmptyForward: { \
- return SparseTensorStorage<P, C, V>::newEmpty( \
- dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \
- true); \
- } \
- case Action::kFromCOO: { \
- assert(ptr && "Received nullptr for SparseTensorCOO object"); \
- auto &coo = *static_cast<SparseTensorCOO<V> *>(ptr); \
- return SparseTensorStorage<P, C, V>::newFromCOO( \
- dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \
- coo); \
+ dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim); \
} \
case Action::kFromReader: { \
assert(ptr && "Received nullptr for SparseTensorReader object"); \
@@ -138,11 +125,6 @@ extern "C" {
return static_cast<void *>(reader.readSparseTensor<P, C, V>( \
lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim)); \
} \
- case Action::kToCOO: { \
- assert(ptr && "Received nullptr for SparseTensorStorage object"); \
- auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr); \
- return tensor.toCOO(); \
- } \
case Action::kPack: { \
assert(ptr && "Received nullptr for SparseTensorStorage object"); \
intptr_t *buffers = static_cast<intptr_t *>(ptr); \
@@ -341,21 +323,6 @@ MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
#undef IMPL_SPARSECOORDINATES
#undef IMPL_GETOVERHEAD
-#define IMPL_FORWARDINGINSERT(VNAME, V) \
- void _mlir_ciface_forwardingInsert##VNAME( \
- void *t, StridedMemRefType<V, 0> *vref, \
- StridedMemRefType<index_type, 1> *dimCoordsRef) { \
- assert(t &&vref); \
- ASSERT_NO_STRIDE(dimCoordsRef); \
- const index_type *dimCoords = MEMREF_GET_PAYLOAD(dimCoordsRef); \
- assert(dimCoords); \
- const V *value = MEMREF_GET_PAYLOAD(vref); \
- static_cast<SparseTensorStorageBase *>(t)->forwardingInsert(dimCoords, \
- *value); \
- }
-MLIR_SPARSETENSOR_FOREVERY_V(IMPL_FORWARDINGINSERT)
-#undef IMPL_FORWARDINGINSERT
-
#define IMPL_LEXINSERT(VNAME, V) \
void _mlir_ciface_lexInsert##VNAME( \
void *t, StridedMemRefType<index_type, 1> *lvlCoordsRef, \
@@ -427,8 +394,8 @@ void _mlir_ciface_getSparseTensorReaderDimSizes(
const uint64_t cSize = MEMREF_GET_USIZE(cref); \
const uint64_t vSize = MEMREF_GET_USIZE(vref); \
ASSERT_USIZE_EQ(lvl2dimRef, dimRank); \
- assert(cSize >= lvlRank * vSize); \
- assert(vSize >= reader.getNSE() && "Not enough space in buffers"); \
+ assert(cSize >= lvlRank * reader.getNSE()); \
+ assert(vSize >= reader.getNSE()); \
(void)dimRank; \
(void)cSize; \
(void)vSize; \
@@ -488,10 +455,6 @@ index_type sparseDimSize(void *tensor, index_type d) {
return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
}
-void endForwardingInsert(void *tensor) {
- return static_cast<SparseTensorStorageBase *>(tensor)->endForwardingInsert();
-}
-
void endLexInsert(void *tensor) {
return static_cast<SparseTensorStorageBase *>(tensor)->endLexInsert();
}
@@ -500,13 +463,6 @@ void delSparseTensor(void *tensor) {
delete static_cast<SparseTensorStorageBase *>(tensor);
}
-#define IMPL_DELCOO(VNAME, V) \
- void delSparseTensorCOO##VNAME(void *coo) { \
- delete static_cast<SparseTensorCOO<V> *>(coo); \
- }
-MLIR_SPARSETENSOR_FOREVERY_V(IMPL_DELCOO)
-#undef IMPL_DELCOO
-
char *getTensorFilename(index_type id) {
constexpr size_t bufSize = 80;
char var[bufSize];
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index 6521295..c631617 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -99,21 +99,31 @@ DIFileAttr DebugImporter::translateImpl(llvm::DIFile *node) {
}
DILabelAttr DebugImporter::translateImpl(llvm::DILabel *node) {
- return DILabelAttr::get(context, translate(node->getScope()),
+ // Return nullptr if the scope or type is a cyclic dependency.
+ DIScopeAttr scope = translate(node->getScope());
+ if (node->getScope() && !scope)
+ return nullptr;
+ return DILabelAttr::get(context, scope,
getStringAttrOrNull(node->getRawName()),
translate(node->getFile()), node->getLine());
}
DILexicalBlockAttr DebugImporter::translateImpl(llvm::DILexicalBlock *node) {
- return DILexicalBlockAttr::get(context, translate(node->getScope()),
- translate(node->getFile()), node->getLine(),
- node->getColumn());
+ // Return nullptr if the scope or type is a cyclic dependency.
+ DIScopeAttr scope = translate(node->getScope());
+ if (node->getScope() && !scope)
+ return nullptr;
+ return DILexicalBlockAttr::get(context, scope, translate(node->getFile()),
+ node->getLine(), node->getColumn());
}
DILexicalBlockFileAttr
DebugImporter::translateImpl(llvm::DILexicalBlockFile *node) {
- return DILexicalBlockFileAttr::get(context, translate(node->getScope()),
- translate(node->getFile()),
+ // Return nullptr if the scope or type is a cyclic dependency.
+ DIScopeAttr scope = translate(node->getScope());
+ if (node->getScope() && !scope)
+ return nullptr;
+ return DILexicalBlockFileAttr::get(context, scope, translate(node->getFile()),
node->getDiscriminator());
}
@@ -135,11 +145,14 @@ DebugImporter::translateImpl(llvm::DIGlobalVariable *node) {
}
DILocalVariableAttr DebugImporter::translateImpl(llvm::DILocalVariable *node) {
- return DILocalVariableAttr::get(context, translate(node->getScope()),
- getStringAttrOrNull(node->getRawName()),
- translate(node->getFile()), node->getLine(),
- node->getArg(), node->getAlignInBits(),
- translate(node->getType()));
+  // Return nullptr if the scope is a cyclic dependency.
+ DIScopeAttr scope = translate(node->getScope());
+ if (node->getScope() && !scope)
+ return nullptr;
+ return DILocalVariableAttr::get(
+ context, scope, getStringAttrOrNull(node->getRawName()),
+ translate(node->getFile()), node->getLine(), node->getArg(),
+ node->getAlignInBits(), translate(node->getType()));
}
DIScopeAttr DebugImporter::translateImpl(llvm::DIScope *node) {
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 97ccb2b..d63ea12 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1966,6 +1966,13 @@ ModuleImport::processDebugIntrinsic(llvm::DbgVariableIntrinsic *dbgIntr,
// TODO: find a way to support this case.
if (isMetadataKillLocation(dbgIntr))
return emitUnsupportedWarning();
+ // Drop debug intrinsics if the associated variable information cannot be
+ // translated due to cyclic debug metadata.
+ // TODO: Support cyclic debug metadata.
+ DILocalVariableAttr localVariableAttr =
+ matchLocalVariableAttr(dbgIntr->getArgOperand(1));
+ if (!localVariableAttr)
+ return emitUnsupportedWarning();
FailureOr<Value> argOperand = convertMetadataValue(dbgIntr->getArgOperand(0));
if (failed(argOperand))
return emitError(loc) << "failed to convert a debug intrinsic operand: "
@@ -1991,8 +1998,6 @@ ModuleImport::processDebugIntrinsic(llvm::DbgVariableIntrinsic *dbgIntr,
} else {
builder.setInsertionPointAfterValue(*argOperand);
}
- DILocalVariableAttr localVariableAttr =
- matchLocalVariableAttr(dbgIntr->getArgOperand(1));
auto locationExprAttr =
debugImporter->translateExpression(dbgIntr->getExpression());
Operation *op =
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 4989ddc..857b601 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -152,519 +152,25 @@ namespace {
/// This class contains a snapshot of the current conversion rewriter state.
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
- RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
- unsigned numReplacements, unsigned numArgReplacements,
- unsigned numRewrites, unsigned numIgnoredOperations)
- : numCreatedOps(numCreatedOps),
- numUnresolvedMaterializations(numUnresolvedMaterializations),
- numReplacements(numReplacements),
- numArgReplacements(numArgReplacements), numRewrites(numRewrites),
- numIgnoredOperations(numIgnoredOperations) {}
-
- /// The current number of created operations.
- unsigned numCreatedOps;
-
- /// The current number of unresolved materializations.
- unsigned numUnresolvedMaterializations;
-
- /// The current number of replacements queued.
- unsigned numReplacements;
-
- /// The current number of argument replacements queued.
- unsigned numArgReplacements;
+ RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
+ unsigned numErased)
+ : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
+ numErased(numErased) {}
/// The current number of rewrites performed.
unsigned numRewrites;
/// The current number of ignored operations.
unsigned numIgnoredOperations;
-};
-
-//===----------------------------------------------------------------------===//
-// OpReplacement
-
-/// This class represents one requested operation replacement via 'replaceOp' or
-/// 'eraseOp`.
-struct OpReplacement {
- OpReplacement(const TypeConverter *converter = nullptr)
- : converter(converter) {}
-
- /// An optional type converter that can be used to materialize conversions
- /// between the new and old values if necessary.
- const TypeConverter *converter;
-};
-
-//===----------------------------------------------------------------------===//
-// UnresolvedMaterialization
-
-/// This class represents an unresolved materialization, i.e. a materialization
-/// that was inserted during conversion that needs to be legalized at the end of
-/// the conversion process.
-class UnresolvedMaterialization {
-public:
- /// The type of materialization.
- enum Kind {
- /// This materialization materializes a conversion for an illegal block
- /// argument type, to a legal one.
- Argument,
-
- /// This materialization materializes a conversion from an illegal type to a
- /// legal one.
- Target
- };
-
- UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr,
- const TypeConverter *converter = nullptr,
- Kind kind = Target, Type origOutputType = nullptr)
- : op(op), converterAndKind(converter, kind),
- origOutputType(origOutputType) {}
-
- /// Return the temporary conversion operation inserted for this
- /// materialization.
- UnrealizedConversionCastOp getOp() const { return op; }
- /// Return the type converter of this materialization (which may be null).
- const TypeConverter *getConverter() const {
- return converterAndKind.getPointer();
- }
-
- /// Return the kind of this materialization.
- Kind getKind() const { return converterAndKind.getInt(); }
-
- /// Set the kind of this materialization.
- void setKind(Kind kind) { converterAndKind.setInt(kind); }
-
- /// Return the original illegal output type of the input values.
- Type getOrigOutputType() const { return origOutputType; }
-
-private:
- /// The unresolved materialization operation created during conversion.
- UnrealizedConversionCastOp op;
-
- /// The corresponding type converter to use when resolving this
- /// materialization, and the kind of this materialization.
- llvm::PointerIntPair<const TypeConverter *, 1, Kind> converterAndKind;
-
- /// The original output type. This is only used for argument conversions.
- Type origOutputType;
+ /// The current number of erased operations/blocks.
+ unsigned numErased;
};
-} // namespace
-
-/// Build an unresolved materialization operation given an output type and set
-/// of input operands.
-static Value buildUnresolvedMaterialization(
- UnresolvedMaterialization::Kind kind, Block *insertBlock,
- Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType,
- Type origOutputType, const TypeConverter *converter,
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
- // Avoid materializing an unnecessary cast.
- if (inputs.size() == 1 && inputs.front().getType() == outputType)
- return inputs.front();
-
- // Create an unresolved materialization. We use a new OpBuilder to avoid
- // tracking the materialization like we do for other operations.
- OpBuilder builder(insertBlock, insertPt);
- auto convertOp =
- builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
- unresolvedMaterializations.emplace_back(convertOp, converter, kind,
- origOutputType);
- return convertOp.getResult(0);
-}
-static Value buildUnresolvedArgumentMaterialization(
- PatternRewriter &rewriter, Location loc, ValueRange inputs,
- Type origOutputType, Type outputType, const TypeConverter *converter,
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
- return buildUnresolvedMaterialization(
- UnresolvedMaterialization::Argument, rewriter.getInsertionBlock(),
- rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
- converter, unresolvedMaterializations);
-}
-static Value buildUnresolvedTargetMaterialization(
- Location loc, Value input, Type outputType, const TypeConverter *converter,
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
- Block *insertBlock = input.getParentBlock();
- Block::iterator insertPt = insertBlock->begin();
- if (OpResult inputRes = dyn_cast<OpResult>(input))
- insertPt = ++inputRes.getOwner()->getIterator();
-
- return buildUnresolvedMaterialization(
- UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input,
- outputType, outputType, converter, unresolvedMaterializations);
-}
-
-//===----------------------------------------------------------------------===//
-// ArgConverter
-//===----------------------------------------------------------------------===//
-namespace {
-/// This class provides a simple interface for converting the types of block
-/// arguments. This is done by creating a new block that contains the new legal
-/// types and extracting the block that contains the old illegal types to allow
-/// for undoing pending rewrites in the case of failure.
-struct ArgConverter {
- ArgConverter(
- PatternRewriter &rewriter,
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations)
- : rewriter(rewriter),
- unresolvedMaterializations(unresolvedMaterializations) {}
-
- /// This structure contains the information pertaining to an argument that has
- /// been converted.
- struct ConvertedArgInfo {
- ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
- Value castValue = nullptr)
- : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
-
- /// The start index of in the new argument list that contains arguments that
- /// replace the original.
- unsigned newArgIdx;
-
- /// The number of arguments that replaced the original argument.
- unsigned newArgSize;
-
- /// The cast value that was created to cast from the new arguments to the
- /// old. This only used if 'newArgSize' > 1.
- Value castValue;
- };
-
- /// This structure contains information pertaining to a block that has had its
- /// signature converted.
- struct ConvertedBlockInfo {
- ConvertedBlockInfo(Block *origBlock, const TypeConverter *converter)
- : origBlock(origBlock), converter(converter) {}
-
- /// The original block that was requested to have its signature converted.
- Block *origBlock;
-
- /// The conversion information for each of the arguments. The information is
- /// std::nullopt if the argument was dropped during conversion.
- SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-
- /// The type converter used to convert the arguments.
- const TypeConverter *converter;
- };
-
- //===--------------------------------------------------------------------===//
- // Rewrite Application
- //===--------------------------------------------------------------------===//
-
- /// Erase any rewrites registered for the blocks within the given operation
- /// which is about to be removed. This merely drops the rewrites without
- /// undoing them.
- void notifyOpRemoved(Operation *op);
-
- /// Cleanup and undo any generated conversions for the arguments of block.
- /// This method replaces the new block with the original, reverting the IR to
- /// its original state.
- void discardRewrites(Block *block);
-
- /// Fully replace uses of the old arguments with the new.
- void applyRewrites(ConversionValueMapping &mapping);
-
- /// Materialize any necessary conversions for converted arguments that have
- /// live users, using the provided `findLiveUser` to search for a user that
- /// survives the conversion process.
- LogicalResult
- materializeLiveConversions(ConversionValueMapping &mapping,
- OpBuilder &builder,
- function_ref<Operation *(Value)> findLiveUser);
-
- //===--------------------------------------------------------------------===//
- // Conversion
- //===--------------------------------------------------------------------===//
-
- /// Attempt to convert the signature of the given block, if successful a new
- /// block is returned containing the new arguments. Returns `block` if it did
- /// not require conversion.
- FailureOr<Block *>
- convertSignature(Block *block, const TypeConverter *converter,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements);
-
- /// Apply the given signature conversion on the given block. The new block
- /// containing the updated signature is returned. If no conversions were
- /// necessary, e.g. if the block has no arguments, `block` is returned.
- /// `converter` is used to generate any necessary cast operations that
- /// translate between the origin argument types and those specified in the
- /// signature conversion.
- Block *applySignatureConversion(
- Block *block, const TypeConverter *converter,
- TypeConverter::SignatureConversion &signatureConversion,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements);
-
- /// A collection of blocks that have had their arguments converted. This is a
- /// map from the new replacement block, back to the original block.
- llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
-
- /// The pattern rewriter to use when materializing conversions.
- PatternRewriter &rewriter;
-
- /// An ordered set of unresolved materializations during conversion.
- SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations;
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// Rewrite Application
-
-void ArgConverter::notifyOpRemoved(Operation *op) {
- if (conversionInfo.empty())
- return;
-
- for (Region &region : op->getRegions()) {
- for (Block &block : region) {
- // Drop any rewrites from within.
- for (Operation &nestedOp : block)
- if (nestedOp.getNumRegions())
- notifyOpRemoved(&nestedOp);
-
- // Check if this block was converted.
- auto *it = conversionInfo.find(&block);
- if (it == conversionInfo.end())
- continue;
-
- // Drop all uses of the original arguments and delete the original block.
- Block *origBlock = it->second.origBlock;
- for (BlockArgument arg : origBlock->getArguments())
- arg.dropAllUses();
- conversionInfo.erase(it);
- }
- }
-}
-
-void ArgConverter::discardRewrites(Block *block) {
- auto *it = conversionInfo.find(block);
- if (it == conversionInfo.end())
- return;
- Block *origBlock = it->second.origBlock;
-
- // Drop all uses of the new block arguments and replace uses of the new block.
- for (int i = block->getNumArguments() - 1; i >= 0; --i)
- block->getArgument(i).dropAllUses();
- block->replaceAllUsesWith(origBlock);
-
- // Move the operations back the original block, move the original block back
- // into its original location and the delete the new block.
- origBlock->getOperations().splice(origBlock->end(), block->getOperations());
- block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
- block->erase();
-
- conversionInfo.erase(it);
-}
-
-void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
- for (auto &info : conversionInfo) {
- ConvertedBlockInfo &blockInfo = info.second;
- Block *origBlock = blockInfo.origBlock;
-
- // Process the remapping for each of the original arguments.
- for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
- std::optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
- BlockArgument origArg = origBlock->getArgument(i);
-
- // Handle the case of a 1->0 value mapping.
- if (!argInfo) {
- if (Value newArg = mapping.lookupOrNull(origArg, origArg.getType()))
- origArg.replaceAllUsesWith(newArg);
- continue;
- }
-
- // Otherwise this is a 1->1+ value mapping.
- Value castValue = argInfo->castValue;
- assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
-
- // If the argument is still used, replace it with the generated cast.
- if (!origArg.use_empty()) {
- origArg.replaceAllUsesWith(
- mapping.lookupOrDefault(castValue, origArg.getType()));
- }
- }
-
- delete origBlock;
- blockInfo.origBlock = nullptr;
- }
-}
-
-LogicalResult ArgConverter::materializeLiveConversions(
- ConversionValueMapping &mapping, OpBuilder &builder,
- function_ref<Operation *(Value)> findLiveUser) {
- for (auto &info : conversionInfo) {
- Block *newBlock = info.first;
- ConvertedBlockInfo &blockInfo = info.second;
- Block *origBlock = blockInfo.origBlock;
-
- // Process the remapping for each of the original arguments.
- for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
- // If the type of this argument changed and the argument is still live, we
- // need to materialize a conversion.
- BlockArgument origArg = origBlock->getArgument(i);
- if (mapping.lookupOrNull(origArg, origArg.getType()))
- continue;
- Operation *liveUser = findLiveUser(origArg);
- if (!liveUser)
- continue;
-
- Value replacementValue = mapping.lookupOrDefault(origArg);
- bool isDroppedArg = replacementValue == origArg;
- if (isDroppedArg)
- rewriter.setInsertionPointToStart(newBlock);
- else
- rewriter.setInsertionPointAfterValue(replacementValue);
- Value newArg;
- if (blockInfo.converter) {
- newArg = blockInfo.converter->materializeSourceConversion(
- rewriter, origArg.getLoc(), origArg.getType(),
- isDroppedArg ? ValueRange() : ValueRange(replacementValue));
- assert((!newArg || newArg.getType() == origArg.getType()) &&
- "materialization hook did not provide a value of the expected "
- "type");
- }
- if (!newArg) {
- InFlightDiagnostic diag =
- emitError(origArg.getLoc())
- << "failed to materialize conversion for block argument #" << i
- << " that remained live after conversion, type was "
- << origArg.getType();
- if (!isDroppedArg)
- diag << ", with target type " << replacementValue.getType();
- diag.attachNote(liveUser->getLoc())
- << "see existing live user here: " << *liveUser;
- return failure();
- }
- mapping.map(origArg, newArg);
- }
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Conversion
-
-FailureOr<Block *> ArgConverter::convertSignature(
- Block *block, const TypeConverter *converter,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements) {
- // Check if the block was already converted.
- // * If the block is mapped in `conversionInfo`, it is a converted block.
- // * If the block is detached, conservatively assume that it is going to be
- // deleted; it is likely the old block (before it was converted).
- if (conversionInfo.count(block) || !block->getParent())
- return block;
- // If a converter wasn't provided, and the block wasn't already converted,
- // there is nothing we can do.
- if (!converter)
- return failure();
-
- // Try to convert the signature for the block with the provided converter.
- if (auto conversion = converter->convertBlockSignature(block))
- return applySignatureConversion(block, converter, *conversion, mapping,
- argReplacements);
- return failure();
-}
-
-Block *ArgConverter::applySignatureConversion(
- Block *block, const TypeConverter *converter,
- TypeConverter::SignatureConversion &signatureConversion,
- ConversionValueMapping &mapping,
- SmallVectorImpl<BlockArgument> &argReplacements) {
- // If no arguments are being changed or added, there is nothing to do.
- unsigned origArgCount = block->getNumArguments();
- auto convertedTypes = signatureConversion.getConvertedTypes();
- if (origArgCount == 0 && convertedTypes.empty())
- return block;
-
- // Split the block at the beginning to get a new block to use for the updated
- // signature.
- Block *newBlock = block->splitBlock(block->begin());
- block->replaceAllUsesWith(newBlock);
- // Unlink the block, but do not erase it yet, so that the change can be rolled
- // back.
- block->getParent()->getBlocks().remove(block);
-
- // Map all new arguments to the location of the argument they originate from.
- SmallVector<Location> newLocs(convertedTypes.size(),
- rewriter.getUnknownLoc());
- for (unsigned i = 0; i < origArgCount; ++i) {
- auto inputMap = signatureConversion.getInputMapping(i);
- if (!inputMap || inputMap->replacementValue)
- continue;
- Location origLoc = block->getArgument(i).getLoc();
- for (unsigned j = 0; j < inputMap->size; ++j)
- newLocs[inputMap->inputNo + j] = origLoc;
- }
-
- SmallVector<Value, 4> newArgRange(
- newBlock->addArguments(convertedTypes, newLocs));
- ArrayRef<Value> newArgs(newArgRange);
-
- // Remap each of the original arguments as determined by the signature
- // conversion.
- ConvertedBlockInfo info(block, converter);
- info.argInfo.resize(origArgCount);
-
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(newBlock);
- for (unsigned i = 0; i != origArgCount; ++i) {
- auto inputMap = signatureConversion.getInputMapping(i);
- if (!inputMap)
- continue;
- BlockArgument origArg = block->getArgument(i);
-
- // If inputMap->replacementValue is not nullptr, then the argument is
- // dropped and a replacement value is provided to be the remappedValue.
- if (inputMap->replacementValue) {
- assert(inputMap->size == 0 &&
- "invalid to provide a replacement value when the argument isn't "
- "dropped");
- mapping.map(origArg, inputMap->replacementValue);
- argReplacements.push_back(origArg);
- continue;
- }
-
- // Otherwise, this is a 1->1+ mapping.
- auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
- Value newArg;
-
- // If this is a 1->1 mapping and the types of new and replacement arguments
- // match (i.e. it's an identity map), then the argument is mapped to its
- // original type.
- // FIXME: We simply pass through the replacement argument if there wasn't a
- // converter, which isn't great as it allows implicit type conversions to
- // appear. We should properly restructure this code to handle cases where a
- // converter isn't provided and also to properly handle the case where an
- // argument materialization is actually a temporary source materialization
- // (e.g. in the case of 1->N).
- if (replArgs.size() == 1 &&
- (!converter || replArgs[0].getType() == origArg.getType())) {
- newArg = replArgs.front();
- } else {
- Type origOutputType = origArg.getType();
-
- // Legalize the argument output type.
- Type outputType = origOutputType;
- if (Type legalOutputType = converter->convertType(outputType))
- outputType = legalOutputType;
-
- newArg = buildUnresolvedArgumentMaterialization(
- rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
- converter, unresolvedMaterializations);
- }
-
- mapping.map(origArg, newArg);
- argReplacements.push_back(origArg);
- info.argInfo[i] =
- ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
- }
-
- conversionInfo.insert({newBlock, std::move(info)});
- return newBlock;
-}
//===----------------------------------------------------------------------===//
// IR rewrites
//===----------------------------------------------------------------------===//
-namespace {
/// An IR rewrite that can be committed (upon success) or rolled back (upon
/// failure).
///
@@ -685,19 +191,29 @@ public:
MoveBlock,
SplitBlock,
BlockTypeConversion,
+ ReplaceBlockArg,
// Operation rewrites
MoveOperation,
- ModifyOperation
+ ModifyOperation,
+ ReplaceOperation,
+ CreateOperation,
+ UnresolvedMaterialization
};
virtual ~IRRewrite() = default;
- /// Roll back the rewrite.
+ /// Roll back the rewrite. Operations may be erased during rollback.
virtual void rollback() = 0;
- /// Commit the rewrite.
+ /// Commit the rewrite. Operations may be unlinked from their blocks during
+ /// the commit phase, but they must not be erased yet. This is because
+ /// internal dialect conversion state (such as `mapping`) may still be using
+ /// them. Operations must be erased during cleanup.
virtual void commit() {}
+ /// Cleanup operations. Cleanup is called after commit.
+ virtual void cleanup() {}
+
Kind getKind() const { return kind; }
static bool classof(const IRRewrite *rewrite) { return true; }
@@ -706,6 +222,14 @@ protected:
IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
: kind(kind), rewriterImpl(rewriterImpl) {}
+ /// Erase the given op (unless it was already erased).
+ void eraseOp(Operation *op);
+
+ /// Erase the given block (unless it was already erased).
+ void eraseBlock(Block *block);
+
+ const ConversionConfig &getConfig() const;
+
const Kind kind;
ConversionPatternRewriterImpl &rewriterImpl;
};
@@ -718,7 +242,7 @@ public:
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::CreateBlock &&
- rewrite->getKind() <= Kind::BlockTypeConversion;
+ rewrite->getKind() <= Kind::ReplaceBlockArg;
}
protected:
@@ -748,8 +272,10 @@ public:
auto &blockOps = block->getOperations();
while (!blockOps.empty())
blockOps.remove(blockOps.begin());
- block->dropAllDefinedValueUses();
- block->erase();
+ if (block->getParent())
+ eraseBlock(block);
+ else
+ delete block;
}
};
@@ -787,6 +313,8 @@ public:
void commit() override {
// Erase the block.
assert(block && "expected block");
+ assert(block->empty() && "expected empty block");
+ block->dropAllDefinedValueUses();
delete block;
block = nullptr;
}
@@ -885,8 +413,7 @@ public:
// Merge back the block that was split out.
originalBlock->getOperations().splice(originalBlock->end(),
block->getOperations());
- block->dropAllDefinedValueUses();
- block->erase();
+ eraseBlock(block);
}
private:
@@ -894,20 +421,80 @@ private:
Block *originalBlock;
};
+/// This structure contains the information pertaining to an argument that has
+/// been converted.
+struct ConvertedArgInfo {
+ ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
+ Value castValue = nullptr)
+ : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
+
+  /// The start index in the new argument list that contains arguments that
+  /// replace the original.
+ unsigned newArgIdx;
+
+ /// The number of arguments that replaced the original argument.
+ unsigned newArgSize;
+
+ /// The cast value that was created to cast from the new arguments to the
+  /// old. This is only used if 'newArgSize' > 1.
+ Value castValue;
+};
+
/// Block type conversion. This rewrite is partially reflected in the IR.
class BlockTypeConversionRewrite : public BlockRewrite {
public:
- BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- Block *block)
- : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block) {}
+ BlockTypeConversionRewrite(
+ ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+ Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
+ const TypeConverter *converter)
+ : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
+ origBlock(origBlock), argInfo(argInfo), converter(converter) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::BlockTypeConversion;
}
- // TODO: Block type conversions are currently committed in
- // `ArgConverter::applyRewrites`. This should be done in the "commit" method.
+ /// Materialize any necessary conversions for converted arguments that have
+ /// live users, using the provided `findLiveUser` to search for a user that
+ /// survives the conversion process.
+ LogicalResult
+ materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
+
+ void commit() override;
+
void rollback() override;
+
+private:
+ /// The original block that was requested to have its signature converted.
+ Block *origBlock;
+
+ /// The conversion information for each of the arguments. The information is
+ /// std::nullopt if the argument was dropped during conversion.
+ SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
+
+ /// The type converter used to convert the arguments.
+ const TypeConverter *converter;
+};
+
+/// Replacing a block argument. This rewrite is not immediately reflected in the
+/// IR. An internal IR mapping is updated, but the actual replacement is delayed
+/// until the rewrite is committed.
+class ReplaceBlockArgRewrite : public BlockRewrite {
+public:
+ ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ Block *block, BlockArgument arg)
+ : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::ReplaceBlockArg;
+ }
+
+ void commit() override;
+
+ void rollback() override;
+
+private:
+ BlockArgument arg;
};
/// An operation rewrite.
@@ -918,7 +505,7 @@ public:
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() >= Kind::MoveOperation &&
- rewrite->getKind() <= Kind::ModifyOperation;
+ rewrite->getKind() <= Kind::UnresolvedMaterialization;
}
protected:
@@ -953,8 +540,8 @@ private:
// The block in which this operation was previously contained.
Block *block;
- // The original successor of this operation before it was moved. "nullptr" if
- // this operation was the only operation in the region.
+ // The original successor of this operation before it was moved. "nullptr"
+ // if this operation was the only operation in the region.
Operation *insertBeforeOp;
};
@@ -1019,6 +606,118 @@ private:
SmallVector<Block *, 2> successors;
void *propertiesStorage = nullptr;
};
+
+/// Replacing an operation. Erasing an operation is treated as a special case
+/// with "null" replacements. This rewrite is not immediately reflected in the
+/// IR. An internal IR mapping is updated, but values are not replaced and the
+/// original op is not erased until the rewrite is committed.
+class ReplaceOperationRewrite : public OperationRewrite {
+public:
+ ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ Operation *op, const TypeConverter *converter,
+ bool changedResults)
+ : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
+ converter(converter), changedResults(changedResults) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::ReplaceOperation;
+ }
+
+ void commit() override;
+
+ void rollback() override;
+
+ void cleanup() override;
+
+ const TypeConverter *getConverter() const { return converter; }
+
+ bool hasChangedResults() const { return changedResults; }
+
+private:
+ /// An optional type converter that can be used to materialize conversions
+ /// between the new and old values if necessary.
+ const TypeConverter *converter;
+
+ /// A boolean flag that indicates whether result types have changed or not.
+ bool changedResults;
+};
+
+class CreateOperationRewrite : public OperationRewrite {
+public:
+ CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ Operation *op)
+ : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::CreateOperation;
+ }
+
+ void rollback() override;
+};
+
+/// The type of materialization.
+enum MaterializationKind {
+ /// This materialization materializes a conversion for an illegal block
+ /// argument type, to a legal one.
+ Argument,
+
+ /// This materialization materializes a conversion from an illegal type to a
+ /// legal one.
+ Target
+};
+
+/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
+/// op. Unresolved materializations are erased at the end of the dialect
+/// conversion.
+class UnresolvedMaterializationRewrite : public OperationRewrite {
+public:
+ UnresolvedMaterializationRewrite(
+ ConversionPatternRewriterImpl &rewriterImpl,
+ UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
+ MaterializationKind kind = MaterializationKind::Target,
+ Type origOutputType = nullptr)
+ : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+ converterAndKind(converter, kind), origOutputType(origOutputType) {}
+
+ static bool classof(const IRRewrite *rewrite) {
+ return rewrite->getKind() == Kind::UnresolvedMaterialization;
+ }
+
+ UnrealizedConversionCastOp getOperation() const {
+ return cast<UnrealizedConversionCastOp>(op);
+ }
+
+ void rollback() override;
+
+ void cleanup() override;
+
+ /// Return the type converter of this materialization (which may be null).
+ const TypeConverter *getConverter() const {
+ return converterAndKind.getPointer();
+ }
+
+ /// Return the kind of this materialization.
+ MaterializationKind getMaterializationKind() const {
+ return converterAndKind.getInt();
+ }
+
+ /// Set the kind of this materialization.
+ void setMaterializationKind(MaterializationKind kind) {
+ converterAndKind.setInt(kind);
+ }
+
+ /// Return the original illegal output type of the input values.
+ Type getOrigOutputType() const { return origOutputType; }
+
+private:
+ /// The corresponding type converter to use when resolving this
+ /// materialization, and the kind of this materialization.
+ llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
+ converterAndKind;
+
+ /// The original output type. This is only used for argument conversions.
+ Type origOutputType;
+};
} // namespace
/// Return "true" if there is an operation rewrite that matches the specified
@@ -1031,23 +730,35 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
});
}
+/// Find the single rewrite object of the specified type and block among the
+/// given rewrites. In debug mode, asserts that there is no more than one such
+/// object. Return "nullptr" if no object was found.
+template <typename RewriteTy, typename R>
+static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
+ RewriteTy *result = nullptr;
+ for (auto &rewrite : rewrites) {
+ auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
+ if (rewriteTy && rewriteTy->getBlock() == block) {
+#ifndef NDEBUG
+ assert(!result && "expected single matching rewrite");
+ result = rewriteTy;
+#else
+ return rewriteTy;
+#endif // NDEBUG
+ }
+ }
+ return result;
+}
+
//===----------------------------------------------------------------------===//
// ConversionPatternRewriterImpl
//===----------------------------------------------------------------------===//
namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
- explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
- : argConverter(rewriter, unresolvedMaterializations),
- notifyCallback(nullptr) {}
-
- /// Cleanup and destroy any generated rewrite operations. This method is
- /// invoked when the conversion process fails.
- void discardRewrites();
-
- /// Apply all requested operation rewrites. This method is invoked when the
- /// conversion process succeeds.
- void applyRewrites();
+ explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
+ const ConversionConfig &config)
+ : eraseRewriter(ctx), config(config) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -1056,6 +767,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Return the current state of the rewriter.
RewriterState getCurrentState();
+ /// Apply all requested operation rewrites. This method is invoked when the
+ /// conversion process succeeds.
+ void applyRewrites();
+
/// Reset the state of the rewriter to a previously saved point.
void resetState(RewriterState state);
@@ -1092,11 +807,18 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// Type Conversion
//===--------------------------------------------------------------------===//
- /// Convert the signature of the given block.
+ /// Attempt to convert the signature of the given block, if successful a new
+ /// block is returned containing the new arguments. Returns `block` if it did
+ /// not require conversion.
FailureOr<Block *> convertBlockSignature(
Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion *conversion = nullptr);
+ /// Convert the types of non-entry block arguments within the given region.
+ LogicalResult convertNonEntryRegionTypes(
+ Region *region, const TypeConverter &converter,
+ ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
+
/// Apply a signature conversion on the given region, using `converter` for
/// materializations if not null.
Block *
@@ -1109,10 +831,37 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
convertRegionTypes(Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion);
- /// Convert the types of non-entry block arguments within the given region.
- LogicalResult convertNonEntryRegionTypes(
- Region *region, const TypeConverter &converter,
- ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
+ /// Apply the given signature conversion on the given block. The new block
+ /// containing the updated signature is returned. If no conversions were
+ /// necessary, e.g. if the block has no arguments, `block` is returned.
+ /// `converter` is used to generate any necessary cast operations that
+ /// translate between the origin argument types and those specified in the
+ /// signature conversion.
+ Block *applySignatureConversion(
+ Block *block, const TypeConverter *converter,
+ TypeConverter::SignatureConversion &signatureConversion);
+
+ //===--------------------------------------------------------------------===//
+ // Materializations
+ //===--------------------------------------------------------------------===//
+ /// Build an unresolved materialization operation given an output type and set
+ /// of input operands.
+ Value buildUnresolvedMaterialization(MaterializationKind kind,
+ Block *insertBlock,
+ Block::iterator insertPt, Location loc,
+ ValueRange inputs, Type outputType,
+ Type origOutputType,
+ const TypeConverter *converter);
+
+ Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
+ ValueRange inputs,
+ Type origOutputType,
+ Type outputType,
+ const TypeConverter *converter);
+
+ Value buildUnresolvedTargetMaterialization(Location loc, Value input,
+ Type outputType,
+ const TypeConverter *converter);
//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -1145,28 +894,51 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
function_ref<void(Diagnostic &)> reasonCallback) override;
//===--------------------------------------------------------------------===//
- // State
+ // IR Erasure
//===--------------------------------------------------------------------===//
- // Mapping between replaced values that differ in type. This happens when
- // replacing a value with one of a different type.
- ConversionValueMapping mapping;
+ /// A rewriter that keeps track of erased ops and blocks. It ensures that no
+ /// operation or block is erased multiple times. This rewriter assumes that
+ /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
+ struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
+ public:
+ SingleEraseRewriter(MLIRContext *context)
+ : RewriterBase(context, /*listener=*/this) {}
+
+ /// Erase the given op (unless it was already erased).
+ void eraseOp(Operation *op) override {
+ if (erased.contains(op))
+ return;
+ op->dropAllUses();
+ RewriterBase::eraseOp(op);
+ }
- /// Utility used to convert block arguments.
- ArgConverter argConverter;
+ /// Erase the given block (unless it was already erased).
+ void eraseBlock(Block *block) override {
+ if (erased.contains(block))
+ return;
+ assert(block->empty() && "expected empty block");
+ block->dropAllDefinedValueUses();
+ RewriterBase::eraseBlock(block);
+ }
- /// Ordered vector of all of the newly created operations during conversion.
- SmallVector<Operation *> createdOps;
+ void notifyOperationErased(Operation *op) override { erased.insert(op); }
+ void notifyBlockErased(Block *block) override { erased.insert(block); }
- /// Ordered vector of all unresolved type conversion materializations during
- /// conversion.
- SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
+ /// Pointers to all erased operations and blocks.
+ SetVector<void *> erased;
+ };
+
+ //===--------------------------------------------------------------------===//
+ // State
+ //===--------------------------------------------------------------------===//
- /// Ordered map of requested operation replacements.
- llvm::MapVector<Operation *, OpReplacement> replacements;
+ /// This rewriter must be used for erasing ops/blocks.
+ SingleEraseRewriter eraseRewriter;
- /// Ordered vector of any requested block argument replacements.
- SmallVector<BlockArgument, 4> argReplacements;
+ // Mapping between replaced values that differ in type. This happens when
+ // replacing a value with one of a different type.
+ ConversionValueMapping mapping;
/// Ordered list of block operations (creations, splits, motions).
SmallVector<std::unique_ptr<IRRewrite>> rewrites;
@@ -1182,11 +954,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// operation was ignored.
SetVector<Operation *> ignoredOps;
- /// A vector of indices into `replacements` of operations that were replaced
- /// with values with different result types than the original operation, e.g.
- /// 1->N conversion of some kind.
- SmallVector<unsigned, 4> operationsWithChangedResults;
-
/// The current type converter, or nullptr if no type converter is currently
/// active.
const TypeConverter *currentTypeConverter = nullptr;
@@ -1195,8 +962,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// converting the arguments of blocks within that region.
DenseMap<Region *, const TypeConverter *> regionToConverter;
- /// This allows the user to collect the match failure message.
- function_ref<void(Diagnostic &)> notifyCallback;
+ /// Dialect conversion configuration.
+ const ConversionConfig &config;
#ifndef NDEBUG
/// A set of operations that have pending updates. This tracking isn't
@@ -1211,154 +978,193 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
} // namespace detail
} // namespace mlir
-void BlockTypeConversionRewrite::rollback() {
- // Undo the type conversion.
- rewriterImpl.argConverter.discardRewrites(block);
-}
-
-/// Detach any operations nested in the given operation from their parent
-/// blocks, and erase the given operation. This can be used when the nested
-/// operations are scheduled for erasure themselves, so deleting the regions of
-/// the given operation together with their content would result in double-free.
-/// This happens, for example, when rolling back op creation in the reverse
-/// order and if the nested ops were created before the parent op. This function
-/// does not need to collect nested ops recursively because it is expected to
-/// also be called for each nested op when it is about to be deleted.
-static void detachNestedAndErase(Operation *op) {
- for (Region &region : op->getRegions()) {
- for (Block &block : region.getBlocks()) {
- while (!block.getOperations().empty())
- block.getOperations().remove(block.getOperations().begin());
- block.dropAllDefinedValueUses();
+void IRRewrite::eraseOp(Operation *op) {
+ rewriterImpl.eraseRewriter.eraseOp(op);
+}
+
+void IRRewrite::eraseBlock(Block *block) {
+ rewriterImpl.eraseRewriter.eraseBlock(block);
+}
+
+const ConversionConfig &IRRewrite::getConfig() const {
+ return rewriterImpl.config;
+}
+
+void BlockTypeConversionRewrite::commit() {
+ // Process the remapping for each of the original arguments.
+ for (auto [origArg, info] :
+ llvm::zip_equal(origBlock->getArguments(), argInfo)) {
+ // Handle the case of a 1->0 value mapping.
+ if (!info) {
+ if (Value newArg =
+ rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
+ origArg.replaceAllUsesWith(newArg);
+ continue;
+ }
+
+ // Otherwise this is a 1->1+ value mapping.
+ Value castValue = info->castValue;
+ assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
+
+ // If the argument is still used, replace it with the generated cast.
+ if (!origArg.use_empty()) {
+ origArg.replaceAllUsesWith(
+ rewriterImpl.mapping.lookupOrDefault(castValue, origArg.getType()));
}
}
- op->dropAllUses();
- op->erase();
+
+ assert(origBlock->empty() && "expected empty block");
+ origBlock->dropAllDefinedValueUses();
+ delete origBlock;
+ origBlock = nullptr;
}
-void ConversionPatternRewriterImpl::discardRewrites() {
- undoRewrites();
+void BlockTypeConversionRewrite::rollback() {
+ // Drop all uses of the new block arguments and replace uses of the new block.
+ for (int i = block->getNumArguments() - 1; i >= 0; --i)
+ block->getArgument(i).dropAllUses();
+ block->replaceAllUsesWith(origBlock);
- // Remove any newly created ops.
- for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
- detachNestedAndErase(materialization.getOp());
- for (auto *op : llvm::reverse(createdOps))
- detachNestedAndErase(op);
+  // Move the operations back to the original block, move the original block
+  // back into its original location, and then delete the new block.
+ origBlock->getOperations().splice(origBlock->end(), block->getOperations());
+ block->getParent()->getBlocks().insert(Region::iterator(block), origBlock);
+ eraseBlock(block);
}
-void ConversionPatternRewriterImpl::applyRewrites() {
- // Apply all of the rewrites replacements requested during conversion.
- for (auto &repl : replacements) {
- for (OpResult result : repl.first->getResults())
- if (Value newValue = mapping.lookupOrNull(result, result.getType()))
- result.replaceAllUsesWith(newValue);
-
- // If this operation defines any regions, drop any pending argument
- // rewrites.
- if (repl.first->getNumRegions())
- argConverter.notifyOpRemoved(repl.first);
- }
-
- // Apply all of the requested argument replacements.
- for (BlockArgument arg : argReplacements) {
- Value repl = mapping.lookupOrNull(arg, arg.getType());
- if (!repl)
+LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
+ function_ref<Operation *(Value)> findLiveUser) {
+ // Process the remapping for each of the original arguments.
+ for (auto it : llvm::enumerate(origBlock->getArguments())) {
+ BlockArgument origArg = it.value();
+ // Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
+ OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
+ builder.setInsertionPointToStart(block);
+
+ // If the type of this argument changed and the argument is still live, we
+ // need to materialize a conversion.
+ if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
continue;
-
- if (isa<BlockArgument>(repl)) {
- arg.replaceAllUsesWith(repl);
+ Operation *liveUser = findLiveUser(origArg);
+ if (!liveUser)
continue;
+
+ Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
+ bool isDroppedArg = replacementValue == origArg;
+ if (!isDroppedArg)
+ builder.setInsertionPointAfterValue(replacementValue);
+ Value newArg;
+ if (converter) {
+ newArg = converter->materializeSourceConversion(
+ builder, origArg.getLoc(), origArg.getType(),
+ isDroppedArg ? ValueRange() : ValueRange(replacementValue));
+ assert((!newArg || newArg.getType() == origArg.getType()) &&
+ "materialization hook did not provide a value of the expected "
+ "type");
}
+ if (!newArg) {
+ InFlightDiagnostic diag =
+ emitError(origArg.getLoc())
+ << "failed to materialize conversion for block argument #"
+ << it.index() << " that remained live after conversion, type was "
+ << origArg.getType();
+ if (!isDroppedArg)
+ diag << ", with target type " << replacementValue.getType();
+ diag.attachNote(liveUser->getLoc())
+ << "see existing live user here: " << *liveUser;
+ return failure();
+ }
+ rewriterImpl.mapping.map(origArg, newArg);
+ }
+ return success();
+}
- // If the replacement value is an operation, we check to make sure that we
- // don't replace uses that are within the parent operation of the
- // replacement value.
- Operation *replOp = cast<OpResult>(repl).getOwner();
- Block *replBlock = replOp->getBlock();
- arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
- Operation *user = operand.getOwner();
- return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
- });
+void ReplaceBlockArgRewrite::commit() {
+ Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
+ if (!repl)
+ return;
+
+ if (isa<BlockArgument>(repl)) {
+ arg.replaceAllUsesWith(repl);
+ return;
}
- // Drop all of the unresolved materialization operations created during
- // conversion.
- for (auto &mat : unresolvedMaterializations) {
- mat.getOp()->dropAllUses();
- mat.getOp()->erase();
+ // If the replacement value is an operation, we check to make sure that we
+ // don't replace uses that are within the parent operation of the
+ // replacement value.
+ Operation *replOp = cast<OpResult>(repl).getOwner();
+ Block *replBlock = replOp->getBlock();
+ arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
+ Operation *user = operand.getOwner();
+ return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
+ });
+}
+
+void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
+
+void ReplaceOperationRewrite::commit() {
+ for (OpResult result : op->getResults())
+ if (Value newValue =
+ rewriterImpl.mapping.lookupOrNull(result, result.getType()))
+ result.replaceAllUsesWith(newValue);
+ if (getConfig().unlegalizedOps)
+ getConfig().unlegalizedOps->erase(op);
+ // Do not erase the operation yet. It may still be referenced in `mapping`.
+ op->getBlock()->getOperations().remove(op);
+}
+
+void ReplaceOperationRewrite::rollback() {
+ for (auto result : op->getResults())
+ rewriterImpl.mapping.erase(result);
+}
+
+void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
+
+void CreateOperationRewrite::rollback() {
+ for (Region &region : op->getRegions()) {
+ while (!region.getBlocks().empty())
+ region.getBlocks().remove(region.getBlocks().begin());
}
+ op->dropAllUses();
+ eraseOp(op);
+}
- // In a second pass, erase all of the replaced operations in reverse. This
- // allows processing nested operations before their parent region is
- // destroyed. Because we process in reverse order, producers may be deleted
- // before their users (a pattern deleting a producer and then the consumer)
- // so we first drop all uses explicitly.
- for (auto &repl : llvm::reverse(replacements)) {
- repl.first->dropAllUses();
- repl.first->erase();
+void UnresolvedMaterializationRewrite::rollback() {
+ if (getMaterializationKind() == MaterializationKind::Target) {
+ for (Value input : op->getOperands())
+ rewriterImpl.mapping.erase(input);
}
+ eraseOp(op);
+}
- argConverter.applyRewrites(mapping);
+void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
+void ConversionPatternRewriterImpl::applyRewrites() {
// Commit all rewrites.
for (auto &rewrite : rewrites)
rewrite->commit();
+ for (auto &rewrite : rewrites)
+ rewrite->cleanup();
}
//===----------------------------------------------------------------------===//
// State Management
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
- return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
- replacements.size(), argReplacements.size(),
- rewrites.size(), ignoredOps.size());
+ return RewriterState(rewrites.size(), ignoredOps.size(),
+ eraseRewriter.erased.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
- // Reset any replaced arguments.
- for (BlockArgument replacedArg :
- llvm::drop_begin(argReplacements, state.numArgReplacements))
- mapping.erase(replacedArg);
- argReplacements.resize(state.numArgReplacements);
-
// Undo any rewrites.
undoRewrites(state.numRewrites);
- // Reset any replaced operations and undo any saved mappings.
- for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
- for (auto result : repl.first->getResults())
- mapping.erase(result);
- while (replacements.size() != state.numReplacements)
- replacements.pop_back();
-
- // Pop all of the newly inserted materializations.
- while (unresolvedMaterializations.size() !=
- state.numUnresolvedMaterializations) {
- UnresolvedMaterialization mat = unresolvedMaterializations.pop_back_val();
- UnrealizedConversionCastOp op = mat.getOp();
-
- // If this was a target materialization, drop the mapping that was inserted.
- if (mat.getKind() == UnresolvedMaterialization::Target) {
- for (Value input : op->getOperands())
- mapping.erase(input);
- }
- detachNestedAndErase(op);
- }
-
- // Pop all of the newly created operations.
- while (createdOps.size() != state.numCreatedOps) {
- detachNestedAndErase(createdOps.back());
- createdOps.pop_back();
- }
-
// Pop all of the recorded ignored operations that are no longer valid.
while (ignoredOps.size() != state.numIgnoredOperations)
ignoredOps.pop_back();
- // Reset operations with changed results.
- while (!operationsWithChangedResults.empty() &&
- operationsWithChangedResults.back() >= state.numReplacements)
- operationsWithChangedResults.pop_back();
+ while (eraseRewriter.erased.size() != state.numErased)
+ eraseRewriter.erased.pop_back();
}
void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
@@ -1413,8 +1219,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
if (currentTypeConverter && desiredType && newOperandType != desiredType) {
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
Value castValue = buildUnresolvedTargetMaterialization(
- operandLoc, newOperand, desiredType, currentTypeConverter,
- unresolvedMaterializations);
+ operandLoc, newOperand, desiredType, currentTypeConverter);
mapping.map(mapping.lookupOrDefault(newOperand), castValue);
newOperand = castValue;
}
@@ -1425,7 +1230,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
// Check to see if this operation was replaced or its parent ignored.
- return replacements.count(op) || ignoredOps.count(op->getParentOp());
+ return ignoredOps.count(op->getParentOp()) ||
+ hasRewrite<ReplaceOperationRewrite>(rewrites, op);
}
void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
@@ -1447,18 +1253,18 @@ void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion *conversion) {
- FailureOr<Block *> result =
- conversion ? argConverter.applySignatureConversion(
- block, converter, *conversion, mapping, argReplacements)
- : argConverter.convertSignature(block, converter, mapping,
- argReplacements);
- if (failed(result))
+ if (conversion)
+ return applySignatureConversion(block, converter, *conversion);
+
+ // If a converter wasn't provided, and the block wasn't already converted,
+ // there is nothing we can do.
+ if (!converter)
return failure();
- if (Block *newBlock = *result) {
- if (newBlock != block)
- appendRewrite<BlockTypeConversionRewrite>(newBlock);
- }
- return result;
+
+ // Try to convert the signature for the block with the provided converter.
+ if (auto conversion = converter->convertBlockSignature(block))
+ return applySignatureConversion(block, converter, *conversion);
+ return failure();
}
Block *ConversionPatternRewriterImpl::applySignatureConversion(
@@ -1512,6 +1318,145 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
return success();
}
+Block *ConversionPatternRewriterImpl::applySignatureConversion(
+ Block *block, const TypeConverter *converter,
+ TypeConverter::SignatureConversion &signatureConversion) {
+ MLIRContext *ctx = eraseRewriter.getContext();
+
+ // If no arguments are being changed or added, there is nothing to do.
+ unsigned origArgCount = block->getNumArguments();
+ auto convertedTypes = signatureConversion.getConvertedTypes();
+ if (llvm::equal(block->getArgumentTypes(), convertedTypes))
+ return block;
+
+ // Split the block at the beginning to get a new block to use for the updated
+ // signature.
+ Block *newBlock = block->splitBlock(block->begin());
+ block->replaceAllUsesWith(newBlock);
+ // Unlink the block, but do not erase it yet, so that the change can be rolled
+ // back.
+ block->getParent()->getBlocks().remove(block);
+
+ // Map all new arguments to the location of the argument they originate from.
+ SmallVector<Location> newLocs(convertedTypes.size(),
+ Builder(ctx).getUnknownLoc());
+ for (unsigned i = 0; i < origArgCount; ++i) {
+ auto inputMap = signatureConversion.getInputMapping(i);
+ if (!inputMap || inputMap->replacementValue)
+ continue;
+ Location origLoc = block->getArgument(i).getLoc();
+ for (unsigned j = 0; j < inputMap->size; ++j)
+ newLocs[inputMap->inputNo + j] = origLoc;
+ }
+
+ SmallVector<Value, 4> newArgRange(
+ newBlock->addArguments(convertedTypes, newLocs));
+ ArrayRef<Value> newArgs(newArgRange);
+
+ // Remap each of the original arguments as determined by the signature
+ // conversion.
+ SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
+ argInfo.resize(origArgCount);
+
+ for (unsigned i = 0; i != origArgCount; ++i) {
+ auto inputMap = signatureConversion.getInputMapping(i);
+ if (!inputMap)
+ continue;
+ BlockArgument origArg = block->getArgument(i);
+
+ // If inputMap->replacementValue is not nullptr, then the argument is
+ // dropped and a replacement value is provided to be the remappedValue.
+ if (inputMap->replacementValue) {
+ assert(inputMap->size == 0 &&
+ "invalid to provide a replacement value when the argument isn't "
+ "dropped");
+ mapping.map(origArg, inputMap->replacementValue);
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+ continue;
+ }
+
+ // Otherwise, this is a 1->1+ mapping.
+ auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
+ Value newArg;
+
+ // If this is a 1->1 mapping and the types of new and replacement arguments
+ // match (i.e. it's an identity map), then the argument is mapped to its
+ // original type.
+ // FIXME: We simply pass through the replacement argument if there wasn't a
+ // converter, which isn't great as it allows implicit type conversions to
+ // appear. We should properly restructure this code to handle cases where a
+ // converter isn't provided and also to properly handle the case where an
+ // argument materialization is actually a temporary source materialization
+ // (e.g. in the case of 1->N).
+ if (replArgs.size() == 1 &&
+ (!converter || replArgs[0].getType() == origArg.getType())) {
+ newArg = replArgs.front();
+ } else {
+ Type origOutputType = origArg.getType();
+
+ // Legalize the argument output type.
+ Type outputType = origOutputType;
+ if (Type legalOutputType = converter->convertType(outputType))
+ outputType = legalOutputType;
+
+ newArg = buildUnresolvedArgumentMaterialization(
+ newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
+ converter);
+ }
+
+ mapping.map(origArg, newArg);
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+ argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
+ }
+
+ appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
+ converter);
+ return newBlock;
+}
+
+//===----------------------------------------------------------------------===//
+// Materializations
+//===----------------------------------------------------------------------===//
+
+/// Build an unresolved materialization operation given an output type and set
+/// of input operands.
+Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
+ MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
+ Location loc, ValueRange inputs, Type outputType, Type origOutputType,
+ const TypeConverter *converter) {
+ // Avoid materializing an unnecessary cast.
+ if (inputs.size() == 1 && inputs.front().getType() == outputType)
+ return inputs.front();
+
+ // Create an unresolved materialization. We use a new OpBuilder to avoid
+ // tracking the materialization like we do for other operations.
+ OpBuilder builder(insertBlock, insertPt);
+ auto convertOp =
+ builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
+ origOutputType);
+ return convertOp.getResult(0);
+}
+Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
+ Block *block, Location loc, ValueRange inputs, Type origOutputType,
+ Type outputType, const TypeConverter *converter) {
+ return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
+ block->begin(), loc, inputs, outputType,
+ origOutputType, converter);
+}
+Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
+ Location loc, Value input, Type outputType,
+ const TypeConverter *converter) {
+ Block *insertBlock = input.getParentBlock();
+ Block::iterator insertPt = insertBlock->begin();
+ if (OpResult inputRes = dyn_cast<OpResult>(input))
+ insertPt = ++inputRes.getOwner()->getIterator();
+
+ return buildUnresolvedMaterialization(MaterializationKind::Target,
+ insertBlock, insertPt, loc, input,
+ outputType, outputType, converter);
+}
+
//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -1523,7 +1468,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
});
if (!previous.isSet()) {
// This is a newly created op.
- createdOps.push_back(op);
+ appendRewrite<CreateOperationRewrite>(op);
return;
}
Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
@@ -1535,7 +1480,12 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
ValueRange newValues) {
assert(newValues.size() == op->getNumResults());
- assert(!replacements.count(op) && "operation was already replaced");
+#ifndef NDEBUG
+ for (auto &rewrite : rewrites)
+ if (auto *opReplacement = dyn_cast<ReplaceOperationRewrite>(rewrite.get()))
+ assert(opReplacement->getOperation() != op &&
+ "operation was already replaced");
+#endif // NDEBUG
// Track if any of the results changed, e.g. erased and replaced with null.
bool resultChanged = false;
@@ -1550,11 +1500,9 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
mapping.map(result, newValue);
resultChanged |= (newValue.getType() != result.getType());
}
- if (resultChanged)
- operationsWithChangedResults.push_back(replacements.size());
- // Record the requested operation replacement.
- replacements.insert(std::make_pair(op, OpReplacement(currentTypeConverter)));
+ appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
+ resultChanged);
// Mark this operation as recursively ignored so that we don't need to
// convert any nested operations.
@@ -1594,8 +1542,8 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
Diagnostic diag(loc, DiagnosticSeverity::Remark);
reasonCallback(diag);
logger.startLine() << "** Failure : " << diag.str() << "\n";
- if (notifyCallback)
- notifyCallback(diag);
+ if (config.notifyCallback)
+ config.notifyCallback(diag);
});
}
@@ -1603,9 +1551,10 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
// ConversionPatternRewriter
//===----------------------------------------------------------------------===//
-ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
+ConversionPatternRewriter::ConversionPatternRewriter(
+ MLIRContext *ctx, const ConversionConfig &config)
: PatternRewriter(ctx),
- impl(new detail::ConversionPatternRewriterImpl(*this)) {
+ impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
setListener(impl.get());
}
@@ -1649,8 +1598,6 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
}
void ConversionPatternRewriter::eraseBlock(Block *block) {
- impl->notifyBlockIsBeingErased(block);
-
// Mark all ops for erasure.
for (Operation &op : *block)
eraseOp(&op);
@@ -1659,6 +1606,7 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
// object and will be actually destroyed when rewrites are applied. This
// allows us to keep the operations in the block live and undo the removal by
// re-inserting the block.
+ impl->notifyBlockIsBeingErased(block);
block->getParent()->getBlocks().remove(block);
}
@@ -1688,7 +1636,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
<< "'(in region of '" << parentOp->getName()
<< "'(" << from.getOwner()->getParentOp() << ")\n";
});
- impl->argReplacements.push_back(from);
+ impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
}
@@ -2022,13 +1970,16 @@ OperationLegalizer::legalizeWithFold(Operation *op,
rewriter.replaceOp(op, replacementValues);
// Recursively legalize any new constant operations.
- for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
+ for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
i != e; ++i) {
- Operation *cstOp = rewriterImpl.createdOps[i];
- if (failed(legalize(cstOp, rewriter))) {
+ auto *createOp =
+ dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
+ if (!createOp)
+ continue;
+ if (failed(legalize(createOp->getOperation(), rewriter))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
- cstOp->getName()));
+ createOp->getOperation()->getName()));
rewriterImpl.resetState(curState);
return failure();
}
@@ -2054,12 +2005,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
LLVM_DEBUG({
logFailure(rewriterImpl.logger, "pattern failed to match");
- if (rewriterImpl.notifyCallback) {
+ if (rewriterImpl.config.notifyCallback) {
Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
diag << "Failed to apply pattern \"" << pattern.getDebugName()
<< "\" on op:\n"
<< *op;
- rewriterImpl.notifyCallback(diag);
+ rewriterImpl.config.notifyCallback(diag);
}
});
rewriterImpl.resetState(curState);
@@ -2112,16 +2063,13 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
#ifndef NDEBUG
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
-
// Check that the root was either replaced or updated in place.
+ auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites);
auto replacedRoot = [&] {
- return llvm::any_of(
- llvm::drop_begin(impl.replacements, curState.numReplacements),
- [op](auto &it) { return it.first == op; });
+ return hasRewrite<ReplaceOperationRewrite>(newRewrites, op);
};
auto updatedRootInPlace = [&] {
- return hasRewrite<ModifyOperationRewrite>(
- llvm::drop_begin(impl.rewrites, curState.numRewrites), op);
+ return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
};
assert((replacedRoot() || updatedRootInPlace()) &&
"expected pattern to replace the root operation");
@@ -2154,7 +2102,8 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
if (!rewrite)
continue;
Block *block = rewrite->getBlock();
- if (isa<BlockTypeConversionRewrite, EraseBlockRewrite>(rewrite))
+ if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
+ ReplaceBlockArgRewrite>(rewrite))
continue;
// Only check blocks outside of the current operation.
Operation *parentOp = block->getParentOp();
@@ -2177,9 +2126,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// blocks in regions created by this pattern will already be legalized later
// on. If we haven't built the set yet, build it now.
if (operationsToIgnore.empty()) {
- auto createdOps = ArrayRef<Operation *>(impl.createdOps)
- .drop_front(state.numCreatedOps);
- operationsToIgnore.insert(createdOps.begin(), createdOps.end());
+ for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
+ ++i) {
+ auto *createOp =
+ dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
+ if (!createOp)
+ continue;
+ operationsToIgnore.insert(createOp->getOperation());
+ }
}
// If this operation should be considered for re-legalization, try it.
@@ -2197,8 +2151,11 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
RewriterState &state, RewriterState &newState) {
- for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
- Operation *op = impl.createdOps[i];
+ for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
+ auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
+ if (!createOp)
+ continue;
+ Operation *op = createOp->getOperation();
if (failed(legalize(op, rewriter))) {
LLVM_DEBUG(logFailure(impl.logger,
"failed to legalize generated operation '{0}'({1})",
@@ -2432,21 +2389,21 @@ enum OpConversionMode {
/// applied to the operations on success.
Analysis,
};
+} // namespace
+namespace mlir {
// This class converts operations to a given conversion target via a set of
// rewrite patterns. The conversion behaves differently depending on the
// conversion mode.
struct OperationConverter {
explicit OperationConverter(const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- OpConversionMode mode,
- DenseSet<Operation *> *trackedOps = nullptr)
- : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
+ const ConversionConfig &config,
+ OpConversionMode mode)
+ : opLegalizer(target, patterns), config(config), mode(mode) {}
/// Converts the given operations to the conversion target.
- LogicalResult
- convertOperations(ArrayRef<Operation *> ops,
- function_ref<void(Diagnostic &)> notifyCallback = nullptr);
+ LogicalResult convertOperations(ArrayRef<Operation *> ops);
private:
/// Converts an operation with the given rewriter.
@@ -2483,16 +2440,13 @@ private:
/// The legalizer to use when converting operations.
OperationLegalizer opLegalizer;
+ /// Dialect conversion configuration.
+ ConversionConfig config;
+
/// The conversion mode to use when legalizing operations.
OpConversionMode mode;
-
- /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
- /// this is populated with ops found to be legalizable to the target.
- /// When mode == OpConversionMode::Partial, this is populated with ops found
- /// *not* to be legalizable to the target.
- DenseSet<Operation *> *trackedOps;
};
-} // namespace
+} // namespace mlir
LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
Operation *op) {
@@ -2504,28 +2458,27 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
return op->emitError()
<< "failed to legalize operation '" << op->getName() << "'";
// Partial conversions allow conversions to fail iff the operation was not
- // explicitly marked as illegal. If the user provided a nonlegalizableOps
- // set, non-legalizable ops are included.
+ // explicitly marked as illegal. If the user provided a `unlegalizedOps`
+ // set, non-legalizable ops are added to that set.
if (mode == OpConversionMode::Partial) {
if (opLegalizer.isIllegal(op))
return op->emitError()
<< "failed to legalize operation '" << op->getName()
<< "' that was explicitly marked illegal";
- if (trackedOps)
- trackedOps->insert(op);
+ if (config.unlegalizedOps)
+ config.unlegalizedOps->insert(op);
}
} else if (mode == OpConversionMode::Analysis) {
// Analysis conversions don't fail if any operations fail to legalize,
// they are only interested in the operations that were successfully
// legalized.
- trackedOps->insert(op);
+ if (config.legalizableOps)
+ config.legalizableOps->insert(op);
}
return success();
}
-LogicalResult OperationConverter::convertOperations(
- ArrayRef<Operation *> ops,
- function_ref<void(Diagnostic &)> notifyCallback) {
+LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (ops.empty())
return success();
const ConversionTarget &target = opLegalizer.getTarget();
@@ -2546,33 +2499,25 @@ LogicalResult OperationConverter::convertOperations(
}
// Convert each operation and discard rewrites on failure.
- ConversionPatternRewriter rewriter(ops.front()->getContext());
+ ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
- rewriterImpl.notifyCallback = notifyCallback;
for (auto *op : toConvert)
if (failed(convert(rewriter, op)))
- return rewriterImpl.discardRewrites(), failure();
+ return rewriterImpl.undoRewrites(), failure();
// Now that all of the operations have been converted, finalize the conversion
// process to ensure any lingering conversion artifacts are cleaned up and
// legalized.
if (failed(finalize(rewriter)))
- return rewriterImpl.discardRewrites(), failure();
+ return rewriterImpl.undoRewrites(), failure();
// After a successful conversion, apply rewrites if this is not an analysis
// conversion.
if (mode == OpConversionMode::Analysis) {
- rewriterImpl.discardRewrites();
+ rewriterImpl.undoRewrites();
} else {
rewriterImpl.applyRewrites();
-
- // It is possible for a later pattern to erase an op that was originally
- // identified as illegal and added to the trackedOps, remove it now after
- // replacements have been computed.
- if (trackedOps)
- for (auto &repl : rewriterImpl.replacements)
- trackedOps->erase(repl.first);
}
return success();
}
@@ -2586,21 +2531,20 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
return failure();
- if (rewriterImpl.operationsWithChangedResults.empty())
- return success();
-
// Process requested operation replacements.
- for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size();
- i != e; ++i) {
- unsigned replIdx = rewriterImpl.operationsWithChangedResults[i];
- auto &repl = *(rewriterImpl.replacements.begin() + replIdx);
- for (OpResult result : repl.first->getResults()) {
+ for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
+ auto *opReplacement =
+ dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get());
+ if (!opReplacement || !opReplacement->hasChangedResults())
+ continue;
+ Operation *op = opReplacement->getOperation();
+ for (OpResult result : op->getResults()) {
Value newValue = rewriterImpl.mapping.lookupOrNull(result);
// If the operation result was replaced with null, all of the uses of this
// value should be replaced.
if (!newValue) {
- if (failed(legalizeErasedResult(repl.first, result, rewriterImpl)))
+ if (failed(legalizeErasedResult(op, result, rewriterImpl)))
return failure();
continue;
}
@@ -2614,15 +2558,11 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
inverseMapping = rewriterImpl.mapping.getInverse();
// Legalize this result.
- rewriter.setInsertionPoint(repl.first);
- if (failed(legalizeChangedResultType(repl.first, result, newValue,
- repl.second.converter, rewriter,
- rewriterImpl, *inverseMapping)))
+ rewriter.setInsertionPoint(op);
+ if (failed(legalizeChangedResultType(
+ op, result, newValue, opReplacement->getConverter(), rewriter,
+ rewriterImpl, *inverseMapping)))
return failure();
-
- // Update the end iterator for this loop in the case it was updated
- // when legalizing generated conversion operations.
- e = rewriterImpl.operationsWithChangedResults.size();
}
}
return success();
@@ -2639,8 +2579,17 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
});
return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
};
- return rewriterImpl.argConverter.materializeLiveConversions(
- rewriterImpl.mapping, rewriter, findLiveUser);
+ // Note: `rewrites` may be reallocated as the loop is running.
+ for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size());
+ ++i) {
+ auto &rewrite = rewriterImpl.rewrites[i];
+ if (auto *blockTypeConversionRewrite =
+ dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
+ if (failed(blockTypeConversionRewrite->materializeLiveConversions(
+ findLiveUser)))
+ return failure();
+ }
+ return success();
}
/// Replace the results of a materialization operation with the given values.
@@ -2672,11 +2621,12 @@ replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl,
/// Compute all of the unresolved materializations that will persist beyond the
/// conversion process, and require inserting a proper user materialization for.
static void computeNecessaryMaterializations(
- DenseMap<Operation *, UnresolvedMaterialization *> &materializationOps,
+ DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+ &materializationOps,
ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
DenseMap<Value, SmallVector<Value>> &inverseMapping,
- SetVector<UnresolvedMaterialization *> &necessaryMaterializations) {
+ SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
auto isLive = [&](Value value) {
auto findFn = [&](Operation *user) {
auto matIt = materializationOps.find(user);
@@ -2711,14 +2661,17 @@ static void computeNecessaryMaterializations(
return Value();
};
- SetVector<UnresolvedMaterialization *> worklist;
- for (auto &mat : rewriterImpl.unresolvedMaterializations) {
- materializationOps.try_emplace(mat.getOp(), &mat);
- worklist.insert(&mat);
+ SetVector<UnresolvedMaterializationRewrite *> worklist;
+ for (auto &rewrite : rewriterImpl.rewrites) {
+ auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
+ if (!mat)
+ continue;
+ materializationOps.try_emplace(mat->getOperation(), mat);
+ worklist.insert(mat);
}
while (!worklist.empty()) {
- UnresolvedMaterialization *mat = worklist.pop_back_val();
- UnrealizedConversionCastOp op = mat->getOp();
+ UnresolvedMaterializationRewrite *mat = worklist.pop_back_val();
+ UnrealizedConversionCastOp op = mat->getOperation();
// We currently only handle target materializations here.
assert(op->getNumResults() == 1 && "unexpected materialization type");
@@ -2760,7 +2713,7 @@ static void computeNecessaryMaterializations(
auto isBlockArg = [](Value v) { return isa<BlockArgument>(v); };
if (llvm::any_of(op->getOperands(), isBlockArg) ||
llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) {
- mat->setKind(UnresolvedMaterialization::Argument);
+ mat->setMaterializationKind(MaterializationKind::Argument);
}
// If the materialization does not have any live users, we don't need to
@@ -2770,7 +2723,7 @@ static void computeNecessaryMaterializations(
// value replacement even if the types differ in some cases. When those
// patterns are fixed, we can drop the argument special case here.
bool isMaterializationLive = isLive(opResult);
- if (mat->getKind() == UnresolvedMaterialization::Argument)
+ if (mat->getMaterializationKind() == MaterializationKind::Argument)
isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive);
if (!isMaterializationLive)
continue;
@@ -2790,8 +2743,9 @@ static void computeNecessaryMaterializations(
/// Legalize the given unresolved materialization. Returns success if the
/// materialization was legalized, failure otherise.
static LogicalResult legalizeUnresolvedMaterialization(
- UnresolvedMaterialization &mat,
- DenseMap<Operation *, UnresolvedMaterialization *> &materializationOps,
+ UnresolvedMaterializationRewrite &mat,
+ DenseMap<Operation *, UnresolvedMaterializationRewrite *>
+ &materializationOps,
ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
DenseMap<Value, SmallVector<Value>> &inverseMapping) {
@@ -2811,7 +2765,7 @@ static LogicalResult legalizeUnresolvedMaterialization(
return Value();
};
- UnrealizedConversionCastOp op = mat.getOp();
+ UnrealizedConversionCastOp op = mat.getOperation();
if (!rewriterImpl.ignoredOps.insert(op))
return success();
@@ -2861,8 +2815,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
rewriter.setInsertionPoint(op);
Value newMaterialization;
- switch (mat.getKind()) {
- case UnresolvedMaterialization::Argument:
+ switch (mat.getMaterializationKind()) {
+ case MaterializationKind::Argument:
// Try to materialize an argument conversion.
// FIXME: The current argument materialization hook expects the original
// output type, even though it doesn't use that as the actual output type
@@ -2879,7 +2833,7 @@ static LogicalResult legalizeUnresolvedMaterialization(
// If an argument materialization failed, fallback to trying a target
// materialization.
[[fallthrough]];
- case UnresolvedMaterialization::Target:
+ case MaterializationKind::Target:
newMaterialization = converter->materializeTargetConversion(
rewriter, op->getLoc(), outputType, inputOperands);
break;
@@ -2907,14 +2861,12 @@ LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) {
- if (rewriterImpl.unresolvedMaterializations.empty())
- return success();
inverseMapping = rewriterImpl.mapping.getInverse();
// As an initial step, compute all of the inserted materializations that we
// expect to persist beyond the conversion process.
- DenseMap<Operation *, UnresolvedMaterialization *> materializationOps;
- SetVector<UnresolvedMaterialization *> necessaryMaterializations;
+ DenseMap<Operation *, UnresolvedMaterializationRewrite *> materializationOps;
+ SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations;
computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
*inverseMapping, necessaryMaterializations);
@@ -3535,57 +3487,51 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
//===----------------------------------------------------------------------===//
// Partial Conversion
-LogicalResult
-mlir::applyPartialConversion(ArrayRef<Operation *> ops,
- const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> *unconvertedOps) {
- OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
- unconvertedOps);
+LogicalResult mlir::applyPartialConversion(
+ ArrayRef<Operation *> ops, const ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns, ConversionConfig config) {
+ OperationConverter opConverter(target, patterns, config,
+ OpConversionMode::Partial);
return opConverter.convertOperations(ops);
}
LogicalResult
mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> *unconvertedOps) {
- return applyPartialConversion(llvm::ArrayRef(op), target, patterns,
- unconvertedOps);
+ ConversionConfig config) {
+ return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
}
//===----------------------------------------------------------------------===//
// Full Conversion
-LogicalResult
-mlir::applyFullConversion(ArrayRef<Operation *> ops,
- const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns) {
- OperationConverter opConverter(target, patterns, OpConversionMode::Full);
+LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
+ const ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns,
+ ConversionConfig config) {
+ OperationConverter opConverter(target, patterns, config,
+ OpConversionMode::Full);
return opConverter.convertOperations(ops);
}
-LogicalResult
-mlir::applyFullConversion(Operation *op, const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns) {
- return applyFullConversion(llvm::ArrayRef(op), target, patterns);
+LogicalResult mlir::applyFullConversion(Operation *op,
+ const ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns,
+ ConversionConfig config) {
+ return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
}
//===----------------------------------------------------------------------===//
// Analysis Conversion
-LogicalResult
-mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
- ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> &convertedOps,
- function_ref<void(Diagnostic &)> notifyCallback) {
- OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
- &convertedOps);
- return opConverter.convertOperations(ops, notifyCallback);
+LogicalResult mlir::applyAnalysisConversion(
+ ArrayRef<Operation *> ops, ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns, ConversionConfig config) {
+ OperationConverter opConverter(target, patterns, config,
+ OpConversionMode::Analysis);
+ return opConverter.convertOperations(ops);
}
LogicalResult
mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> &convertedOps,
- function_ref<void(Diagnostic &)> notifyCallback) {
- return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns,
- convertedOps, notifyCallback);
+ ConversionConfig config) {
+ return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
}
diff --git a/mlir/test/CAPI/llvm.c b/mlir/test/CAPI/llvm.c
index 5a78fac..1817988 100644
--- a/mlir/test/CAPI/llvm.c
+++ b/mlir/test/CAPI/llvm.c
@@ -15,6 +15,7 @@
#include "mlir-c/Support.h"
#include <assert.h>
+#include <inttypes.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
@@ -105,7 +106,7 @@ static int testStructTypeCreation(MlirContext ctx) {
// CHECK: i8
// CHECK: i32
// CHECK: i64
- fprintf(stderr, "num elements: %ld\n",
+ fprintf(stderr, "num elements: %" PRIdPTR "\n",
mlirLLVMStructTypeGetNumElementTypes(literal));
for (intptr_t i = 0; i < 3; ++i) {
mlirTypeDump(mlirLLVMStructTypeGetElementType(literal, i));
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 3de2f11d..56129db 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -77,6 +77,18 @@ func.func @log1p_2dvector_fmf(%arg0 : vector<4x3xf32>) {
// -----
+// CHECK-LABEL: func @log1p_scalable_vector(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @log1p_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+ // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %[[VEC]] : vector<[4]xf32>
+ // CHECK: %[[LOG:.*]] = llvm.intr.log(%[[ADD]]) : (vector<[4]xf32>) -> vector<[4]xf32>
+ %0 = math.log1p %arg0 : vector<[4]xf32>
+ func.return %0 : vector<[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @expm1(
// CHECK-SAME: f32
func.func @expm1(%arg0 : f32) {
@@ -113,6 +125,18 @@ func.func @expm1_vector(%arg0 : vector<4xf32>) {
// -----
+// CHECK-LABEL: func @expm1_scalable_vector(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @expm1_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+ // CHECK: %[[EXP:.*]] = llvm.intr.exp(%[[VEC]]) : (vector<[4]xf32>) -> vector<[4]xf32>
+ // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : vector<[4]xf32>
+ %0 = math.expm1 %arg0 : vector<[4]xf32>
+ func.return %0 : vector<[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @expm1_vector_fmf(
// CHECK-SAME: vector<4xf32>
func.func @expm1_vector_fmf(%arg0 : vector<4xf32>) {
@@ -177,6 +201,16 @@ func.func @cttz_vec(%arg0 : vector<4xi32>) {
// -----
+// CHECK-LABEL: func @cttz_scalable_vec(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xi32>
+func.func @cttz_scalable_vec(%arg0 : vector<[4]xi32>) -> vector<[4]xi32> {
+ // CHECK: "llvm.intr.cttz"(%[[VEC]]) <{is_zero_poison = false}> : (vector<[4]xi32>) -> vector<[4]xi32>
+ %0 = math.cttz %arg0 : vector<[4]xi32>
+ func.return %0 : vector<[4]xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @ctpop(
// CHECK-SAME: i32
func.func @ctpop(%arg0 : i32) {
@@ -197,6 +231,16 @@ func.func @ctpop_vector(%arg0 : vector<3xi32>) {
// -----
+// CHECK-LABEL: func @ctpop_scalable_vector(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xi32>
+func.func @ctpop_scalable_vector(%arg0 : vector<[4]xi32>) -> vector<[4]xi32> {
+ // CHECK: llvm.intr.ctpop(%[[VEC]]) : (vector<[4]xi32>) -> vector<[4]xi32>
+ %0 = math.ctpop %arg0 : vector<[4]xi32>
+ func.return %0 : vector<[4]xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @rsqrt_double(
// CHECK-SAME: f64
func.func @rsqrt_double(%arg0 : f64) {
@@ -233,6 +277,18 @@ func.func @rsqrt_vector(%arg0 : vector<4xf32>) {
// -----
+// CHECK-LABEL: func @rsqrt_scalable_vector(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @rsqrt_scalable_vector(%arg0 : vector<[4]xf32>) -> vector<[4]xf32>{
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+ // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[VEC]]) : (vector<[4]xf32>) -> vector<[4]xf32>
+ // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<[4]xf32>
+ %0 = math.rsqrt %arg0 : vector<[4]xf32>
+ func.return %0 : vector<[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @rsqrt_vector_fmf(
// CHECK-SAME: vector<4xf32>
func.func @rsqrt_vector_fmf(%arg0 : vector<4xf32>) {
@@ -245,6 +301,18 @@ func.func @rsqrt_vector_fmf(%arg0 : vector<4xf32>) {
// -----
+// CHECK-LABEL: func @rsqrt_scalable_vector_fmf(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
+func.func @rsqrt_scalable_vector_fmf(%arg0 : vector<[4]xf32>) -> vector<[4]xf32> {
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+ // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[VEC]]) {fastmathFlags = #llvm.fastmath<fast>} : (vector<[4]xf32>) -> vector<[4]xf32>
+ // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] {fastmathFlags = #llvm.fastmath<fast>} : vector<[4]xf32>
+ %0 = math.rsqrt %arg0 fastmath<fast> : vector<[4]xf32>
+ func.return %0 : vector<[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @rsqrt_multidim_vector(
func.func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) {
// CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
@@ -258,6 +326,19 @@ func.func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) {
// -----
+// CHECK-LABEL: func @rsqrt_multidim_scalable_vector(
+func.func @rsqrt_multidim_scalable_vector(%arg0 : vector<4x[4]xf32>) -> vector<4x[4]xf32> {
+ // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<[4]xf32>>
+ // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
+ // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%[[EXTRACT]]) : (vector<[4]xf32>) -> vector<[4]xf32>
+ // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] : vector<[4]xf32>
+ // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[DIV]], %{{.*}}[0] : !llvm.array<4 x vector<[4]xf32>>
+ %0 = math.rsqrt %arg0 : vector<4x[4]xf32>
+ func.return %0 : vector<4x[4]xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @fpowi(
// CHECK-SAME: f64
func.func @fpowi(%arg0 : f64, %arg1 : i32) {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index febe74e..1fa783f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -759,6 +759,21 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () {
// -----
+// CHECK-LABEL: @test_i64
+func.func @test_i64(%arg0: tensor<1xi64>) -> () {
+ // CHECK: linalg.generic
+ // CHECK: ^bb0(%[[ARG1:.+]]: i64,
+ // CHECK-DAG: %[[C127:.+]] = arith.constant -9223372036854775808
+ // CHECK-DAG: %[[C126:.+]] = arith.constant 9223372036854775807
+ // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]]
+ // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]]
+ %0 = tosa.clamp %arg0 {min_int = -9223372036854775808 : i64, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi64>) -> tensor<1xi64>
+
+ return
+}
+
+// -----
+
// CHECK-LABEL: @test_clamp_f16
func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () {
// CHECK: linalg.generic
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 7adde31..206d7e9 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -102,17 +102,16 @@ func.func @tensor.cast.unranked(%a : tensor<*xf32>, %b : tensor<*xf32>, %c : ten
// -----
// CHECK-LABEL: func @linalg_effects(
-// CHECK-SAME: %[[A:[a-z0-9]*]]: tensor<?x?xf32>
-// CHECK-SAME: %[[B:[a-z0-9]*]]: memref<?x?xf32>
-// CHECK-SAME: %[[C:[a-z0-9]*]]: tensor<?x?xf32>
-func.func @linalg_effects(%a : tensor<?x?xf32>, %b : memref<?x?xf32>, %c : tensor<?x?xf32>) {
+func.func @linalg_effects(
+ %a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : tensor<?x?xf32>,
+ %d : memref<?x?xf32>, %e : memref<?x?xf32>, %f : memref<?x?xf32>) {
// CHECK-NOT: %{{.*}} = linalg.matmul
- %t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, memref<?x?xf32>)
+ %t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: linalg.matmul
- linalg.matmul ins(%a, %c : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%b : memref<?x?xf32>)
+ linalg.matmul ins(%d, %e : memref<?x?xf32>, memref<?x?xf32>)
+ outs(%f : memref<?x?xf32>)
return
}
@@ -889,11 +888,11 @@ func.func @fold_multi_use_generic_op_with_consumer(%arg0 : tensor<?x?x?xf32>) ->
// -----
#map = affine_map<(d0) -> (d0)>
-func.func @identity_mixed(%arg0 : tensor<?xf32>, %arg1: memref<?xf32>) {
+func.func @identity_buffer(%arg0 : memref<?xf32>, %arg1: memref<?xf32>) {
linalg.generic {
indexing_maps = [#map, #map],
iterator_types = ["parallel"]
- } ins(%arg0 : tensor<?xf32>)
+ } ins(%arg0 : memref<?xf32>)
outs(%arg1 : memref<?xf32>) {
^bb0(%arg2 : f32, %arg3 : f32):
linalg.yield %arg2 : f32
@@ -901,14 +900,13 @@ func.func @identity_mixed(%arg0 : tensor<?xf32>, %arg1: memref<?xf32>) {
return
}
-// There was a crash in EraseIdentityGenericOp for generic with mixed semantics.
-// For now, check generic remained unchanged.
-// CHECK-LABEL: func @identity_mixed
-// CHECK-SAME: (%[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: memref<?xf32>)
+// Do not erase ops with buffer semantics.
+// CHECK-LABEL: func @identity_buffer
+// CHECK-SAME: (%[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf32>)
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#map, #map],
// CHECK-SAME: iterator_types = ["parallel"]
-// CHECK-SAME: } ins(%[[ARG1]] : tensor<?xf32>)
+// CHECK-SAME: } ins(%[[ARG1]] : memref<?xf32>)
// CHECK-SAME: outs(%[[ARG2]] : memref<?xf32>) {
// -----
@@ -916,12 +914,12 @@ func.func @identity_mixed(%arg0 : tensor<?xf32>, %arg1: memref<?xf32>) {
// Just make sure that we don't crash.
// CHECK-LABEL: func @dedeplicate_regression_test
-func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: memref<4xf32>) {
+func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: tensor<4xf32>) {
%36 = linalg.generic
{indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
- ins(%1, %1 : memref<4xf32>, memref<4xf32>)
+ ins(%1, %1 : tensor<4xf32>, tensor<4xf32>)
outs(%0 : tensor<4xf32>) {
^bb0(%in: f32, %in_24: f32, %out: f32):
linalg.yield %in : f32
@@ -937,31 +935,6 @@ func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: memref<4xf32>) {
// -----
-#map = affine_map<(d0) -> (d0)>
-func.func @cast_producer_mixed(%arg0 : tensor<5xf32>, %arg1: memref<?xf32>) {
- %0 = tensor.cast %arg0 : tensor<5xf32> to tensor<?xf32>
- linalg.generic {
- indexing_maps = [#map, #map],
- iterator_types = ["parallel"]
- } ins(%0 : tensor<?xf32>)
- outs(%arg1 : memref<?xf32>) {
- ^bb0(%arg2 : f32, %arg3 : f32):
- linalg.yield %arg2 : f32
- }
- return
-}
-
-// We need a mixed linalg as a bridge between tensor and memref worlds.
-// CHECK-LABEL: func @cast_producer_mixed
-// CHECK-SAME: (%[[ARG1:.*]]: tensor<5xf32>, %[[ARG2:.*]]: memref<?xf32>)
-// CHECK: linalg.generic {
-// CHECK-SAME: indexing_maps = [#map, #map],
-// CHECK-SAME: iterator_types = ["parallel"]
-// CHECK-SAME: } ins(%[[ARG1]] : tensor<5xf32>)
-// CHECK-SAME: outs(%[[ARG2]] : memref<?xf32>) {
-
-// -----
-
// CHECK-LABEL: dead_softmax
func.func @dead_softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
%0 = tensor.empty() : tensor<16x64x256xf32>
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 9d8421c..15a4f6c 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1110,43 +1110,3 @@ module {
// CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
// CHECK: linalg.yield %[[T3]] : f32
// CHECK: return %[[GENERIC]]
-
-// -----
-
-// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-
-// CHECK-LABEL: @mixed_fusion
-func.func @mixed_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>, %arg8 : memref<?x?xf32>)
-{
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
- %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
- %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
- %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%2 : tensor<?x?xf32>) {
- ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
- %4 = arith.addf %arg3, %arg4 : f32
- linalg.yield %4 : f32
- } -> tensor<?x?xf32>
- // CHECK: linalg.generic {
- // CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP0]], [[$MAP0]], [[$MAP0]]{{\]}}
- linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
- ins(%3, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
- outs(%arg8 : memref<?x?xf32>) {
- // CHECK: ^{{[a-zA-Z0-9_]*}}
- // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]
- // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]
- // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]
- ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
- // CHECK: [[T1:%[a-zA-Z0-9_]*]] = arith.addf [[ARG0]], [[ARG1]]
- // CHECK-NOT: linalg.yield
- // CHECK: arith.mulf [[T1]], [[ARG2]]
- // CHECK: linalg.yield
- %5 = arith.mulf %arg5, %arg6 : f32
- linalg.yield %5 : f32
- }
- return
-}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 916c04f..44c81c3 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -770,3 +770,13 @@ func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
-> tensor<8x8xf32>
return %res : tensor<8x8xf32>
}
+
+// -----
+
+func.func @mixed_semantics(%a: tensor<?x?xf32>, %b: tensor<?x?xf32>, %c: memref<?x?xf32>) {
+ // expected-error @+1 {{expected to have pure tensor or buffer semantics}}
+ linalg.matmul ins(%a, %b: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%c: memref<?x?xf32>)
+ return
+}
+
diff --git a/mlir/test/Dialect/Mesh/spmdization.mlir b/mlir/test/Dialect/Mesh/spmdization.mlir
index 2fb8029..572d3eb 100644
--- a/mlir/test/Dialect/Mesh/spmdization.mlir
+++ b/mlir/test/Dialect/Mesh/spmdization.mlir
@@ -127,3 +127,17 @@ func.func @multiple_chained_ops(
// CHECK: return %[[RESHARD3]] : tensor<1xi8>
return %7 : tensor<2xi8>
}
+
+// CHECK-LABEL: func @incomplete_sharding
+func.func @incomplete_sharding(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<4x16xf32>
+ %arg0: tensor<8x16xf32>
+// CHECK-SAME: -> tensor<4x16xf32> {
+) -> tensor<8x16xf32> {
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> annotate_for_users : tensor<8x16xf32>
+ // CHECK: %[[RES:.*]] = tosa.sigmoid %[[ARG]] : (tensor<4x16xf32>) -> tensor<4x16xf32>
+ %1 = tosa.sigmoid %0 : (tensor<8x16xf32>) -> tensor<8x16xf32>
+ %2 = mesh.shard %1 to <@mesh_1d, [[0]]> : tensor<8x16xf32>
+ // CHECK: return %[[RES]] : tensor<4x16xf32>
+ return %2 : tensor<8x16xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 3d68464..01b2707 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -376,6 +376,13 @@ func.func @test_clz(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
}
// -----
+// CHECK-LABEL: cos
+func.func @test_cos(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = tosa.cos %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
// CHECK-LABEL: exp
func.func @test_exp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.exp %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
@@ -425,6 +432,13 @@ func.func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
}
// -----
+// CHECK-LABEL: sin
+func.func @test_sin(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+ %0 = tosa.sin %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
// CHECK-LABEL: select
func.func @test_select(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.select %arg0, %arg1, %arg2 : (tensor<1x1x1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 02063a8..94e78ce 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -195,53 +195,89 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
// CHECK-LABEL: func.func @aligned_extsi(
func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
- // CHECK: arith.shli
- // CHECK: arith.shrsi
- // CHECK: arith.shrsi
- // CHECK: vector.shuffle
- // CHECK: arith.extsi %{{.*}} : vector<8xi8> to vector<8xi32>
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
%0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
return %0 : vector<8xi32>
}
+// CHECK-LABEL: func.func @aligned_extsi_2d(
+func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+ %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
+ return %0 : vector<8x32xi32>
+}
+
// CHECK-LABEL: func.func @aligned_extsi_base_case(
func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
- // CHECK: arith.shli
- // CHECK: arith.shrsi
- // CHECK: arith.shrsi
- // CHECK: vector.shuffle
- // CHECK-NOT: arith.extsi
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
%0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
return %0 : vector<8xi8>
}
// CHECK-LABEL: func.func @aligned_sitofp(
func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
- // CHECK: arith.shli
- // CHECK: arith.shrsi
- // CHECK: arith.shrsi
- // CHECK: shuffle
- // CHECK: arith.sitofp %{{.*}} : vector<8xi8> to vector<8xf32>
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
%0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
return %0 : vector<8xf32>
}
+// CHECK-LABEL: func.func @aligned_sitofp_2d(
+func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
+ %0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32>
+ return %0 : vector<8x32xf32>
+}
+
// CHECK-LABEL: func.func @i4_transpose(
-// CHECK-SAME: %[[A:[0-9a-z]*]]
func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
- // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi4> to vector<8x16xi8>
- // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
- // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
+// CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
+// CHECK: %[[EXT:.*]] = vector.interleave
+// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
%0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
return %0 : vector<16x8xi4>
}
// CHECK-LABEL: func.func @i7_transpose(
-// CHECK-SAME: %[[A:[0-9a-z]*]]
func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
- // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi7> to vector<8x16xi8>
- // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
- // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
+// CHECK-SAME: %[[IN:.*]]: vector<8x16xi7>) -> vector<16x8xi7> {
+// CHECK: %[[EXT:.*]] = arith.extsi %[[IN]] : vector<8x16xi7> to vector<8x16xi8>
+// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
%0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7>
return %0 : vector<16x8xi7>
}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 1775b5f..788ae9a 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -83,7 +83,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
return
}
-// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)>
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
@@ -92,7 +92,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK: %[[C_0:.*]] = arith.constant 0 : i32
// CHECK: %[[C_0_IDX:.*]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
-// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_2]], %[[IDX_1]]]
+// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
// CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
// CHECK: %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
// CHECK: vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
@@ -459,3 +459,37 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
// CHECK-128B-LABEL: func @fold_unit_dims_entirely(
// CHECK-128B-NOT: memref.collapse_shape
+
+// -----
+
+func.func @regression_non_contiguous_dim_read(%subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+ %idx0 : index, %idx1 : index) -> vector<2x2xf32> {
+ %c0 = arith.constant 0 : index
+ %cst_1 = arith.constant 0.000000e+00 : f32
+ %8 = vector.transfer_read %subview[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
+ return %8 : vector<2x2xf32>
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL: func.func @regression_non_contiguous_dim_read(
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+
+// CHECK-128B-LABEL: func @regression_non_contiguous_dim_read(
+// CHECK-128B: memref.collapse_shape
+
+// -----
+
+func.func @unsupported_non_contiguous_dim_write(%value : vector<2x2xf32>,
+ %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+ %idx0 : index, %idx1 : index) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %value, %subview[%c0, %idx0, %idx1, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
+ return
+}
+
+// CHECK-LABEL: func.func @unsupported_non_contiguous_dim_write(
+// CHECK-NOT: memref.collapse_shape
+
+// CHECK-128B-LABEL: func @unsupported_non_contiguous_dim_write(
+// CHECK-128B-NOT: memref.collapse_shape
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
index 44ff1af..12f13e8 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
@@ -1,13 +1,8 @@
// RUN: mlir-opt %s \
-// RUN: -transform-interpreter \
-// RUN: -test-transform-dialect-erase-schedule \
+// RUN: -transform-interpreter -test-transform-dialect-erase-schedule \
// RUN: -lower-vector-mask \
// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
-// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// RUN: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
-// RUN: -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
-// RUN: -convert-arm-sme-to-llvm -cse -canonicalize \
-// RUN: -test-lower-to-llvm | \
+// RUN: -test-lower-to-arm-sme -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
// RUN: -e=entry -entry-point-result=void \
// RUN: -march=aarch64 -mattr="+sve,+sme" \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
index c781d5e..34c5351 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
@@ -1,12 +1,7 @@
// RUN: mlir-opt %s \
// RUN: -transform-interpreter -test-transform-dialect-erase-schedule \
-// RUN: -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
-// RUN: -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
-// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
-// RUN: -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
-// RUN: -convert-arm-sme-to-llvm \
-// RUN: -convert-vector-to-llvm=enable-arm-sve \
-// RUN: -cse -canonicalize -test-lower-to-llvm | \
+// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
+// RUN: -test-lower-to-arm-sme -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
// RUN: -e=main -entry-point-result=void \
// RUN: -march=aarch64 -mattr="+sve,+sme" \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
index 31c3202..2bfdaa8 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
@@ -1,12 +1,6 @@
// RUN: mlir-opt %s \
// RUN: -transform-interpreter -test-transform-dialect-erase-schedule \
-// RUN: -canonicalize \
-// RUN: -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
-// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
-// RUN: -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
-// RUN: -convert-arm-sme-to-llvm \
-// RUN: -convert-vector-to-llvm=enable-arm-sve \
-// RUN: -cse -canonicalize -test-lower-to-llvm | \
+// RUN: -test-lower-to-arm-sme -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
// RUN: -e=main -entry-point-result=void \
// RUN: -march=aarch64 -mattr="+sve,+sme" \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
index d5c3506..e376bdd 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
@@ -1,11 +1,7 @@
// RUN: mlir-opt %s \
// RUN: -transform-interpreter -test-transform-dialect-erase-schedule \
// RUN: -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
-// RUN: -arm-sme-vector-legalization -canonicalize -cse \
-// RUN: -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
-// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
-// RUN: -convert-vector-to-scf=full-unroll -convert-arm-sme-to-llvm \
-// RUN: -test-lower-to-llvm | \
+// RUN: -test-lower-to-arm-sme -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
// RUN: -e=main -entry-point-result=void \
// RUN: -march=aarch64 -mattr="+sve,+sme" \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
index 42fe21c..ee3866de 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
@@ -1,10 +1,5 @@
// RUN: mlir-opt %s \
-// RUN: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
-// RUN: -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
-// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
-// RUN: -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
-// RUN: -convert-arm-sme-to-llvm -convert-vector-to-llvm=enable-arm-sve -cse \
-// RUN: -canonicalize -test-lower-to-llvm -verify-diagnostics | \
+// RUN: -test-lower-to-arm-sme -test-lower-to-llvm -verify-diagnostics | \
// RUN: %mcr_aarch64_cmd \
// RUN: -e=main -entry-point-result=void \
// RUN: -march=aarch64 -mattr="+sve,+sme" \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
index 59b4a7e..06b1c10 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir
@@ -1,9 +1,5 @@
// DEFINE: %{entry_point} = test_load_store_zaq0
-// DEFINE: %{compile} = mlir-opt %s \
-// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
index 064141c..27be801 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
@@ -1,9 +1,5 @@
// DEFINE: %{entry_point} = entry
-// DEFINE: %{compile} = mlir-opt %s \
-// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
-// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE: -test-lower-to-llvm
+// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir
index 0827d9b..9d836d9 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir
@@ -1,10 +1,4 @@
-// RUN: mlir-opt %s -arm-sme-vector-legalization -cse -canonicalize \
-// RUN: -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
-// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
-// RUN: -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
-// RUN: -convert-arm-sme-to-llvm \
-// RUN: -convert-vector-to-llvm=enable-arm-sve \
-// RUN: -cse -canonicalize -test-lower-to-llvm | \
+// RUN: mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
// RUN: -e=main -entry-point-result=void \
// RUN: -march=aarch64 -mattr="+sve,+sme" \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir
index f081838..a06ad37 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir
@@ -1,11 +1,7 @@
+// DEFINE: %{opts} =
// DEFINE: %{entry} = main
-// DEFINE: %{fusion_opts} = -arm-sme-outer-product-fusion
// DEFINE: %{compile} = mlir-opt %s \
-// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme %{fusion_opts} \
-// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
-// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
-// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE: -test-lower-to-llvm -o %t
+// DEFINE: -test-lower-to-arm-sme=%{opts} -test-lower-to-llvm -o %t
// DEFINE: %{run} = %mcr_aarch64_cmd %t \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry} -entry-point-result=void \
@@ -18,7 +14,7 @@
// Check result is the same when outerproducts are not combined into widening
// variant.
-// REDEFINE: %{fusion_opts} =
+// REDEFINE: %{opts} = fuse-outer-products=false
// RUN: %{run} | FileCheck %s
func.func @main() {
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index 5f41b37..7e7869d 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -1,10 +1,6 @@
// DEFINE: %{entry_point} = test_outerproduct_no_accumulator_4x4xf32
// DEFINE: %{compile} = mlir-opt %s \
-// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
-// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
-// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE: -test-lower-to-llvm -o %t
+// DEFINE: -test-lower-to-arm-sme -test-lower-to-llvm -o %t
// DEFINE: %{run} = %mcr_aarch64_cmd %t \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index a1bb9b7..46bf799 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -1,10 +1,6 @@
// DEFINE: %{entry_point} = test_outerproduct_no_accumulator_2x2xf64
// DEFINE: %{compile} = mlir-opt %s \
-// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
-// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
-// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE: -test-lower-to-llvm -o %t
+// DEFINE: -test-lower-to-arm-sme -test-lower-to-llvm -o %t
// DEFINE: %{run} = %mcr_aarch64_cmd %t \
// DEFINE: -march=aarch64 -mattr=+sve,+sme-f64f64 \
// DEFINE: -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir
index 1770e57..9a353ec 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir
@@ -1,11 +1,5 @@
// DEFINE: %{entry} = main
-// DEFINE: %{compile} = mlir-opt %s \
-// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
-// DEFINE: -arm-sme-outer-product-fusion \
-// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
-// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
-// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE: -test-lower-to-llvm
+// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
index 6e028d5..52f5688 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
@@ -1,9 +1,5 @@
// DEFINE: %{entry_point} = entry
-// DEFINE: %{compile} = mlir-opt %s \
-// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
-// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
-// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE: -test-lower-to-llvm
+// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
index c0c1f55..710cc66 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -1,10 +1,5 @@
// DEFINE: %{entry_point} = entry
-// DEFINE: %{compile} = mlir-opt %s \
-// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
-// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
-// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
-// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE: -test-lower-to-llvm
+// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
index eee3c56..88bc0d0 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
@@ -1,9 +1,5 @@
// DEFINE: %{entry_point} = entry
-// DEFINE: %{compile} = mlir-opt %s \
-// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
-// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE: -test-lower-to-llvm
+// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
index 223bc8c..e149174 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
@@ -1,8 +1,4 @@
-// RUN: mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// RUN: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
-// RUN: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
-// RUN: -convert-arm-sme-to-llvm -cse -canonicalize \
-// RUN: -test-lower-to-llvm | \
+// RUN: mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
// RUN: -march=aarch64 -mattr=+sve,+sme \
// RUN: -e entry -entry-point-result=i32 \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index 2f151e2..b29790db 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -1,9 +1,5 @@
// DEFINE: %{entry_point} = za0_d_f64
-// DEFINE: %{compile} = mlir-opt %s \
-// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
-// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE: -test-lower-to-llvm
+// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=i32 \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
index f28bf19..c8c401b 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -1,8 +1,5 @@
// DEFINE: %{entry_point} = entry
-// DEFINE: %{compile} = mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
-// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
-// DEFINE: -convert-arm-sme-to-llvm -test-lower-to-llvm
+// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=i32 \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir
new file mode 100644
index 0000000..07989bd
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s -test-lower-to-llvm | \
+// RUN: %mcr_aarch64_cmd -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_c_runner_utils,%mlir_arm_runner_utils \
+// RUN: -march=aarch64 -mattr=+sve | \
+// RUN: FileCheck %s
+
+func.func @entry() {
+ %f1 = arith.constant 1.0 : f32
+ %f2 = arith.constant 2.0 : f32
+ %v1 = vector.splat %f1 : vector<[4]xf32>
+ %v2 = vector.splat %f2 : vector<[4]xf32>
+ vector.print %v1 : vector<[4]xf32>
+ vector.print %v2 : vector<[4]xf32>
+ //
+ // Test vectors:
+ //
+ // CHECK: ( 1, 1, 1, 1
+ // CHECK: ( 2, 2, 2, 2
+
+ %v3 = vector.interleave %v1, %v2 : vector<[4]xf32>
+ vector.print %v3 : vector<[8]xf32>
+ // CHECK: ( 1, 2, 1, 2, 1, 2, 1, 2
+
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir
new file mode 100644
index 0000000..0bc78af
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -test-lower-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @entry() {
+ %f1 = arith.constant 1.0 : f32
+ %f2 = arith.constant 2.0 : f32
+ %v1 = vector.splat %f1 : vector<2x4xf32>
+ %v2 = vector.splat %f2 : vector<2x4xf32>
+ vector.print %v1 : vector<2x4xf32>
+ vector.print %v2 : vector<2x4xf32>
+ //
+ // Test vectors:
+ //
+ // CHECK: ( ( 1, 1, 1, 1 ), ( 1, 1, 1, 1 ) )
+ // CHECK: ( ( 2, 2, 2, 2 ), ( 2, 2, 2, 2 ) )
+
+ %v3 = vector.interleave %v1, %v2 : vector<2x4xf32>
+ vector.print %v3 : vector<2x8xf32>
+ // CHECK: ( ( 1, 2, 1, 2, 1, 2, 1, 2 ), ( 1, 2, 1, 2, 1, 2, 1, 2 ) )
+
+ return
+}
diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll
index 0962134..9a4e939 100644
--- a/mlir/test/Target/LLVMIR/Import/import-failure.ll
+++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll
@@ -59,13 +59,15 @@ define void @unhandled_intrinsic() gc "example" {
; // -----
+; Check that debug intrinsics with an unsupported argument are dropped.
+
declare void @llvm.dbg.value(metadata, metadata, metadata)
; CHECK: import-failure.ll
-; CHECK-SAME: warning: dropped intrinsic: call void @llvm.dbg.value(metadata !DIArgList(i64 %arg1, i64 undef), metadata !3, metadata !DIExpression(DW_OP_LLVM_arg, 0, DW_OP_LLVM_arg, 1, DW_OP_constu, 1, DW_OP_mul, DW_OP_plus, DW_OP_stack_value)), !dbg !5
+; CHECK-SAME: warning: dropped intrinsic: call void @llvm.dbg.value(metadata !DIArgList(i64 %{{.*}}, i64 undef), metadata !3, metadata !DIExpression(DW_OP_LLVM_arg, 0, DW_OP_LLVM_arg, 1, DW_OP_constu, 1, DW_OP_mul, DW_OP_plus, DW_OP_stack_value))
; CHECK: import-failure.ll
-; CHECK-SAME: warning: dropped intrinsic: call void @llvm.dbg.value(metadata !6, metadata !3, metadata !DIExpression()), !dbg !5
-define void @dropped_instruction(i64 %arg1) {
+; CHECK-SAME: warning: dropped intrinsic: call void @llvm.dbg.value(metadata !6, metadata !3, metadata !DIExpression())
+define void @unsupported_argument(i64 %arg1) {
call void @llvm.dbg.value(metadata !DIArgList(i64 %arg1, i64 undef), metadata !3, metadata !DIExpression(DW_OP_LLVM_arg, 0, DW_OP_LLVM_arg, 1, DW_OP_constu, 1, DW_OP_mul, DW_OP_plus, DW_OP_stack_value)), !dbg !5
call void @llvm.dbg.value(metadata !6, metadata !3, metadata !DIExpression()), !dbg !5
ret void
@@ -83,6 +85,38 @@ define void @dropped_instruction(i64 %arg1) {
; // -----
+; Check that debug intrinsics that depend on cyclic metadata are dropped.
+
+declare void @llvm.dbg.value(metadata, metadata, metadata)
+
+; CHECK: import-failure.ll
+; CHECK-SAME: warning: dropped instruction: call void @llvm.dbg.label(metadata !{{.*}})
+; CHECK: import-failure.ll
+; CHECK-SAME: warning: dropped intrinsic: call void @llvm.dbg.value(metadata i64 %{{.*}}, metadata !3, metadata !DIExpression())
+define void @cyclic_metadata(i64 %arg1) {
+ call void @llvm.dbg.value(metadata i64 %arg1, metadata !10, metadata !DIExpression()), !dbg !14
+ call void @llvm.dbg.label(metadata !13), !dbg !14
+ ret void
+}
+
+!llvm.dbg.cu = !{!1}
+!llvm.module.flags = !{!0}
+!0 = !{i32 2, !"Debug Info Version", i32 3}
+!1 = distinct !DICompileUnit(language: DW_LANG_C, file: !2)
+!2 = !DIFile(filename: "import-failure.ll", directory: "/")
+!3 = !DICompositeType(tag: DW_TAG_array_type, size: 42, baseType: !4)
+!4 = !DIDerivedType(tag: DW_TAG_pointer_type, baseType: !3)
+!5 = distinct !DISubprogram(name: "class_method", scope: !2, file: !2, type: !6, spFlags: DISPFlagDefinition, unit: !1)
+!6 = !DISubroutineType(types: !7)
+!7 = !{!3}
+!10 = !DILocalVariable(scope: !5, name: "arg1", file: !2, line: 1, arg: 1, align: 64)
+!11 = !DILexicalBlock(scope: !5)
+!12 = !DILexicalBlockFile(scope: !11, discriminator: 0)
+!13 = !DILabel(scope: !12, name: "label", file: !2, line: 42)
+!14 = !DILocation(line: 1, column: 2, scope: !5)
+
+; // -----
+
; global_dtors with non-null data fields cannot be represented in MLIR.
; CHECK: <unknown>
; CHECK-SAME: error: unhandled global variable: @llvm.global_dtors
diff --git a/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt b/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt
new file mode 100644
index 0000000..e942c7b
--- /dev/null
+++ b/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt
@@ -0,0 +1,18 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRArmSMETestPasses
+ TestLowerToArmSME.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ LINK_LIBS PUBLIC
+ MLIRArithToArmSME
+ MLIRArmSMEToLLVM
+ MLIRArmSMEToSCF
+ MLIRArmSMETransforms
+ MLIRArmSVETransforms
+ MLIRIR
+ MLIRPass
+ MLIRTransforms
+ MLIRVectorToArmSME
+ MLIRVectorToSCF
+ )
diff --git a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
new file mode 100644
index 0000000..48d4a58
--- /dev/null
+++ b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
@@ -0,0 +1,99 @@
+//===- TestLowerToArmSME.cpp - Test lowering to ArmSME as a sink pass -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing the lowering to ArmSME as a
+// generally usable sink pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
+#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
+#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
+#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassOptions.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace {
+struct TestLowerToArmSMEOptions
+ : public PassPipelineOptions<TestLowerToArmSMEOptions> {
+ PassOptions::Option<bool> fuseOuterProducts{
+ *this, "fuse-outer-products",
+ llvm::cl::desc("Fuse outer product operations via "
+ "'-arm-sme-outer-product-fusion' pass"),
+ llvm::cl::init(true)};
+};
+
+void buildTestLowerToArmSME(OpPassManager &pm,
+ const TestLowerToArmSMEOptions &options) {
+ // Legalize vector operations so they can be converted to ArmSME.
+ pm.addPass(arm_sme::createVectorLegalizationPass());
+
+ // Sprinkle some cleanups.
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
+
+ // Passes that convert operations on vectors to ArmSME operations.
+
+ // Convert Arith to ArmSME.
+ pm.addPass(createArithToArmSMEConversionPass());
+ // Convert Vector to ArmSME.
+ pm.addPass(createConvertVectorToArmSMEPass());
+
+ // Fuse outer products.
+ if (options.fuseOuterProducts)
+ pm.addPass(arm_sme::createOuterProductFusionPass());
+
+ // Convert operations on high-level vectors to loops.
+
+ // Convert ArmSME to SCF.
+ pm.addPass(createConvertArmSMEToSCFPass());
+
+ // Convert Vector to SCF (with full unroll enabled).
+ pm.addPass(createConvertVectorToSCFPass(
+ VectorTransferToSCFOptions().enableFullUnroll()));
+
+ // Allocate tiles for ArmSME operations.
+ //
+ // Later passes may create further ArmSME ops that implement the
+ // ArmSMETileOpInterface, but tiles are allocated for root operations,
+ // all of which should now exist.
+ pm.addPass(arm_sme::createTileAllocationPass());
+
+ // Enable streaming-mode and ZA.
+ pm.addPass(arm_sme::createEnableArmStreamingPass(
+ arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA,
+ /*onlyIfRequiredByOps=*/true));
+
+ // Convert ArmSME to LLVM.
+ pm.addPass(createConvertArmSMEToLLVMPass());
+
+ // Sprinkle some cleanups.
+ pm.addPass(createCanonicalizerPass());
+ pm.addPass(createCSEPass());
+}
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestLowerToArmSME() {
+ PassPipelineRegistration<TestLowerToArmSMEOptions>(
+ "test-lower-to-arm-sme",
+ "An example pipeline to lower operations on vectors (arith, vector) to "
+ "LLVM via ArmSME.",
+ buildTestLowerToArmSME);
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 30a17c2..e20cd44 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -1,5 +1,6 @@
add_subdirectory(Affine)
add_subdirectory(Arith)
+add_subdirectory(ArmSME)
add_subdirectory(Bufferization)
add_subdirectory(ControlFlow)
add_subdirectory(DLTI)
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 108cfe8..bde4255 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1152,8 +1152,10 @@ struct TestLegalizePatternDriver
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
DenseSet<Operation *> unlegalizedOps;
- if (failed(applyPartialConversion(
- getOperation(), target, std::move(patterns), &unlegalizedOps))) {
+ ConversionConfig config;
+ config.unlegalizedOps = &unlegalizedOps;
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns), config))) {
getOperation()->emitRemark() << "applyPartialConversion failed";
}
// Emit remarks for each legalizable operation.
@@ -1181,8 +1183,10 @@ struct TestLegalizePatternDriver
// Analyze the convertible operations.
DenseSet<Operation *> legalizedOps;
+ ConversionConfig config;
+ config.legalizableOps = &legalizedOps;
if (failed(applyAnalysisConversion(getOperation(), target,
- std::move(patterns), legalizedOps)))
+ std::move(patterns), config)))
return signalPassFailure();
// Emit remarks for each legalizable operation.
@@ -1806,8 +1810,10 @@ struct TestMergeBlocksPatternDriver
});
DenseSet<Operation *> unlegalizedOps;
+ ConversionConfig config;
+ config.unlegalizedOps = &unlegalizedOps;
(void)applyPartialConversion(getOperation(), target, std::move(patterns),
- &unlegalizedOps);
+ config);
for (auto *op : unlegalizedOps)
op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
}
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 68aa6ba..701fc46 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -17,6 +17,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRTestFuncToLLVM
MLIRAffineTransformsTestPasses
MLIRArithTestPasses
+ MLIRArmSMETestPasses
MLIRBufferizationTestPasses
MLIRControlFlowTestPasses
MLIRDLTITestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index f11c6b4..4dfa05c 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -109,6 +109,7 @@ void registerTestLoopFusion();
void registerTestCFGLoopInfoPass();
void registerTestLoopMappingPass();
void registerTestLoopUnrollingPass();
+void registerTestLowerToArmSME();
void registerTestLowerToLLVM();
void registerTestMakeIsolatedFromAbovePass();
void registerTestMatchReductionPass();
@@ -233,6 +234,7 @@ void registerTestPasses() {
mlir::test::registerTestCFGLoopInfoPass();
mlir::test::registerTestLoopMappingPass();
mlir::test::registerTestLoopUnrollingPass();
+ mlir::test::registerTestLowerToArmSME();
mlir::test::registerTestLowerToLLVM();
mlir::test::registerTestMakeIsolatedFromAbovePass();
mlir::test::registerTestMatchReductionPass();
diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index 3a6bcbd..9d2f690 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -77,7 +77,7 @@ protected:
}
// Inserts an Integer or a Vector of Integers constant of value 'val'.
- spirv::ConstantOp AddConstInt(Type type, const APInt &val) {
+ spirv::ConstantOp addConstInt(Type type, const APInt &val) {
OpBuilder builder(module->getRegion());
auto loc = UnknownLoc::get(&context);
@@ -181,8 +181,8 @@ TEST_F(SerializationTest, SignlessVsSignedIntegerConstantBitExtension) {
APInt signedIntConstVal(signedInt16Type.getWidth(), -1,
signedInt16Type.getSignedness());
- AddConstInt(signlessInt16Type, signlessIntConstVal);
- AddConstInt(signedInt16Type, signedIntConstVal);
+ addConstInt(signlessInt16Type, signlessIntConstVal);
+ addConstInt(signedInt16Type, signedIntConstVal);
ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary)));
auto hasSignlessVal = [&](spirv::Opcode opcode, ArrayRef<uint32_t> operands) {
diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
index 2e1309a..16de34c 100644
--- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -421,7 +421,7 @@ TEST(InterfaceAttachmentTest, PromisedInterfaces) {
// Attribute interfaces use the exact same mechanism as types, so just check
// that the promise mechanism works for attributes.
MLIRContext context;
- auto testDialect = context.getOrLoadDialect<test::TestDialect>();
+ auto *testDialect = context.getOrLoadDialect<test::TestDialect>();
auto attr = test::SimpleAAttr::get(&context);
// `SimpleAAttr` doesn't implement nor promises the
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index 8a4f67b..9d75615 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -295,9 +295,9 @@ TEST(OperationEquivalenceTest, HashWorksWithFlags) {
MLIRContext context;
context.getOrLoadDialect<test::TestDialect>();
- auto op1 = createOp(&context);
+ auto *op1 = createOp(&context);
// `op1` has an unknown loc.
- auto op2 = createOp(&context);
+ auto *op2 = createOp(&context);
op2->setLoc(NameLoc::get(StringAttr::get(&context, "foo")));
auto getHash = [](Operation *op, OperationEquivalence::Flags flags) {
return OperationEquivalence::computeHash(
diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
index a00ebba..26bfbd5 100644
--- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
@@ -37,7 +37,7 @@ using namespace mlir;
class MLIRTargetLLVMNVVM : public ::testing::Test {
protected:
- virtual void SetUp() {
+ void SetUp() override {
registerBuiltinDialectTranslation(registry);
registerLLVMDialectTranslation(registry);
registerGPUDialectTranslation(registry);
@@ -85,7 +85,7 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(SerializeNVVMMToLLVM)) {
serializer.serializeToObject(gpuModule, options);
// Check that the serializer was successful.
ASSERT_TRUE(object != std::nullopt);
- ASSERT_TRUE(object->size() > 0);
+ ASSERT_TRUE(!object->empty());
// Read the serialized module.
llvm::MemoryBufferRef buffer(StringRef(object->data(), object->size()),
@@ -121,7 +121,7 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(SerializeNVVMToPTX)) {
serializer.serializeToObject(gpuModule, options);
// Check that the serializer was successful.
ASSERT_TRUE(object != std::nullopt);
- ASSERT_TRUE(object->size() > 0);
+ ASSERT_TRUE(!object->empty());
ASSERT_TRUE(
StringRef(object->data(), object->size()).contains("nvvm_kernel"));
@@ -151,6 +151,6 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(SerializeNVVMToBinary)) {
serializer.serializeToObject(gpuModule, options);
// Check that the serializer was successful.
ASSERT_TRUE(object != std::nullopt);
- ASSERT_TRUE(object->size() > 0);
+ ASSERT_TRUE(!object->empty());
}
}