aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRiver Riddle <riddleriver@gmail.com>2020-12-17 12:24:45 -0800
committerRiver Riddle <riddleriver@gmail.com>2020-12-17 13:01:36 -0800
commit1b97cdf885d6455841280b8da858835e641ee941 (patch)
treeacd8503bb6e0b42cb9a06d3670742c5f958990ae
parent511cfe9441955f20a8b93573fb9b62433b053550 (diff)
downloadllvm-1b97cdf885d6455841280b8da858835e641ee941.zip
llvm-1b97cdf885d6455841280b8da858835e641ee941.tar.gz
llvm-1b97cdf885d6455841280b8da858835e641ee941.tar.bz2
[mlir][IR][NFC] Move context/location parameters of builtin Type::get methods to the start of the parameter list
This better matches the rest of the infrastructure, is much simpler, and makes it easier to move these types to being declaratively specified. Differential Revision: https://reviews.llvm.org/D93432
-rw-r--r--flang/include/flang/Optimizer/Dialect/FIROps.td2
-rw-r--r--flang/lib/Lower/ConvertType.cpp4
-rw-r--r--flang/lib/Lower/IntrinsicCall.cpp18
-rw-r--r--flang/lib/Lower/RTBuilder.h22
-rw-r--r--flang/lib/Optimizer/Dialect/FIROps.cpp5
-rw-r--r--mlir/include/mlir/Dialect/AVX512/AVX512.td8
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td2
-rw-r--r--mlir/include/mlir/IR/BuiltinTypes.h47
-rw-r--r--mlir/include/mlir/IR/OpBase.td5
-rw-r--r--mlir/lib/CAPI/IR/BuiltinTypes.cpp28
-rw-r--r--mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp28
-rw-r--r--mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp2
-rw-r--r--mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp2
-rw-r--r--mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp2
-rw-r--r--mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp3
-rw-r--r--mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp4
-rw-r--r--mlir/lib/Dialect/Async/IR/Async.cpp2
-rw-r--r--mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp2
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp2
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp2
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp6
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp10
-rw-r--r--mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp6
-rw-r--r--mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp2
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/Utils.cpp2
-rw-r--r--mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp2
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp6
-rw-r--r--mlir/lib/Dialect/StandardOps/IR/Ops.cpp6
-rw-r--r--mlir/lib/Dialect/Vector/VectorOps.cpp4
-rw-r--r--mlir/lib/Dialect/Vector/VectorTransforms.cpp2
-rw-r--r--mlir/lib/IR/Builders.cpp14
-rw-r--r--mlir/lib/IR/BuiltinDialect.cpp2
-rw-r--r--mlir/lib/IR/BuiltinTypes.cpp47
-rw-r--r--mlir/lib/IR/Dialect.cpp2
-rw-r--r--mlir/lib/IR/MLIRContext.cpp18
-rw-r--r--mlir/lib/IR/Operation.cpp2
-rw-r--r--mlir/lib/IR/Value.cpp2
-rw-r--r--mlir/lib/Parser/DialectSymbolParser.cpp4
-rw-r--r--mlir/lib/Parser/TypeParser.cpp4
-rw-r--r--mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp2
-rw-r--r--mlir/lib/Target/SPIRV/Deserialization.cpp4
-rw-r--r--mlir/lib/Target/SPIRV/Serialization.cpp2
-rw-r--r--mlir/lib/Transforms/BufferResultsToOutParams.cpp4
-rw-r--r--mlir/lib/Transforms/NormalizeMemRefs.cpp11
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp4
-rw-r--r--mlir/test/EDSC/builder-api-test.cpp14
-rw-r--r--mlir/test/lib/Dialect/Test/TestDialect.cpp2
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp2
-rw-r--r--mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp4
-rw-r--r--mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp10
-rw-r--r--mlir/unittests/IR/AttributeTest.cpp20
-rw-r--r--mlir/unittests/TableGen/StructsGenTest.cpp4
52 files changed, 205 insertions, 209 deletions
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 8d7a6d4..cecd1cb 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2176,7 +2176,7 @@ def fir_DispatchOp : fir_Op<"dispatch",
p.printOptionalAttrDict(getAttrs(), {"fn_type", "method"});
auto resTy{getResultTypes()};
llvm::SmallVector<mlir::Type, 8> argTy(getOperandTypes());
- p << " : " << mlir::FunctionType::get(argTy, resTy, getContext());
+ p << " : " << mlir::FunctionType::get(getContext(), argTy, resTy);
}];
let extraClassDeclaration = [{
diff --git a/flang/lib/Lower/ConvertType.cpp b/flang/lib/Lower/ConvertType.cpp
index 746d7ad..b3fa85d 100644
--- a/flang/lib/Lower/ConvertType.cpp
+++ b/flang/lib/Lower/ConvertType.cpp
@@ -49,7 +49,7 @@ mlir::Type genFIRType(mlir::MLIRContext *context) {
if constexpr (TC == Fortran::common::TypeCategory::Integer) {
auto bits{Fortran::evaluate::Type<Fortran::common::TypeCategory::Integer,
KIND>::Scalar::bits};
- return mlir::IntegerType::get(bits, context);
+ return mlir::IntegerType::get(context, bits);
} else if constexpr (TC == Fortran::common::TypeCategory::Logical ||
TC == Fortran::common::TypeCategory::Character ||
TC == Fortran::common::TypeCategory::Complex) {
@@ -278,7 +278,7 @@ private:
// some sequence of `n` bytes
mlir::Type gen(const Fortran::evaluate::StaticDataObject::Pointer &ptr) {
- mlir::Type byteTy{mlir::IntegerType::get(8, context)};
+ mlir::Type byteTy{mlir::IntegerType::get(context, 8)};
return fir::SequenceType::get(trivialShape(ptr->itemBytes()), byteTy);
}
diff --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp
index 0e0081e..7053cd9 100644
--- a/flang/lib/Lower/IntrinsicCall.cpp
+++ b/flang/lib/Lower/IntrinsicCall.cpp
@@ -298,26 +298,26 @@ static constexpr RuntimeFunction pgmathPrecise[] = {
static mlir::FunctionType genF32F32FuncType(mlir::MLIRContext *context) {
auto t = mlir::FloatType::getF32(context);
- return mlir::FunctionType::get({t}, {t}, context);
+ return mlir::FunctionType::get(context, {t}, {t});
}
static mlir::FunctionType genF64F64FuncType(mlir::MLIRContext *context) {
auto t = mlir::FloatType::getF64(context);
- return mlir::FunctionType::get({t}, {t}, context);
+ return mlir::FunctionType::get(context, {t}, {t});
}
template <int Bits>
static mlir::FunctionType genIntF64FuncType(mlir::MLIRContext *context) {
auto t = mlir::FloatType::getF64(context);
- auto r = mlir::IntegerType::get(Bits, context);
- return mlir::FunctionType::get({t}, {r}, context);
+ auto r = mlir::IntegerType::get(context, Bits);
+ return mlir::FunctionType::get(context, {t}, {r});
}
template <int Bits>
static mlir::FunctionType genIntF32FuncType(mlir::MLIRContext *context) {
auto t = mlir::FloatType::getF32(context);
- auto r = mlir::IntegerType::get(Bits, context);
- return mlir::FunctionType::get({t}, {r}, context);
+ auto r = mlir::IntegerType::get(context, Bits);
+ return mlir::FunctionType::get(context, {t}, {r});
}
// TODO : Fill-up this table with more intrinsic.
@@ -585,8 +585,8 @@ getFunctionType(mlir::Type resultType, llvm::ArrayRef<mlir::Value> arguments,
llvm::SmallVector<mlir::Type, 2> argumentTypes;
for (auto &arg : arguments)
argumentTypes.push_back(arg.getType());
- return mlir::FunctionType::get(argumentTypes, resultType,
- builder.getModule().getContext());
+ return mlir::FunctionType::get(builder.getModule().getContext(),
+ argumentTypes, resultType);
}
/// fir::ExtendedValue to mlir::Value translation layer
@@ -1144,7 +1144,7 @@ mlir::Value IntrinsicLibrary::genMerge(mlir::Type,
llvm::ArrayRef<mlir::Value> args) {
assert(args.size() == 3);
- auto i1Type = mlir::IntegerType::get(1, builder.getContext());
+ auto i1Type = mlir::IntegerType::get(builder.getContext(), 1);
auto mask = builder.createConvert(loc, i1Type, args[2]);
return builder.create<mlir::SelectOp>(loc, mask, args[0], args[1]);
}
diff --git a/flang/lib/Lower/RTBuilder.h b/flang/lib/Lower/RTBuilder.h
index 4805eb3..4b130b6 100644
--- a/flang/lib/Lower/RTBuilder.h
+++ b/flang/lib/Lower/RTBuilder.h
@@ -48,7 +48,7 @@ static constexpr TypeBuilderFunc getModel();
template <>
constexpr TypeBuilderFunc getModel<int>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
- return mlir::IntegerType::get(8 * sizeof(int), context);
+ return mlir::IntegerType::get(context, 8 * sizeof(int));
};
}
template <>
@@ -61,14 +61,14 @@ constexpr TypeBuilderFunc getModel<int &>() {
template <>
constexpr TypeBuilderFunc getModel<Fortran::runtime::io::Iostat>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
- return mlir::IntegerType::get(8 * sizeof(Fortran::runtime::io::Iostat),
- context);
+ return mlir::IntegerType::get(context,
+ 8 * sizeof(Fortran::runtime::io::Iostat));
};
}
template <>
constexpr TypeBuilderFunc getModel<char *>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
- return fir::ReferenceType::get(mlir::IntegerType::get(8, context));
+ return fir::ReferenceType::get(mlir::IntegerType::get(context, 8));
};
}
template <>
@@ -78,26 +78,26 @@ constexpr TypeBuilderFunc getModel<const char *>() {
template <>
constexpr TypeBuilderFunc getModel<const char16_t *>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
- return fir::ReferenceType::get(mlir::IntegerType::get(16, context));
+ return fir::ReferenceType::get(mlir::IntegerType::get(context, 16));
};
}
template <>
constexpr TypeBuilderFunc getModel<const char32_t *>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
- return fir::ReferenceType::get(mlir::IntegerType::get(32, context));
+ return fir::ReferenceType::get(mlir::IntegerType::get(context, 32));
};
}
template <>
constexpr TypeBuilderFunc getModel<void **>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
return fir::ReferenceType::get(
- fir::PointerType::get(mlir::IntegerType::get(8, context)));
+ fir::PointerType::get(mlir::IntegerType::get(context, 8)));
};
}
template <>
constexpr TypeBuilderFunc getModel<std::int64_t>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
- return mlir::IntegerType::get(64, context);
+ return mlir::IntegerType::get(context, 64);
};
}
template <>
@@ -110,7 +110,7 @@ constexpr TypeBuilderFunc getModel<std::int64_t &>() {
template <>
constexpr TypeBuilderFunc getModel<std::size_t>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
- return mlir::IntegerType::get(8 * sizeof(std::size_t), context);
+ return mlir::IntegerType::get(context, 8 * sizeof(std::size_t));
};
}
template <>
@@ -146,7 +146,7 @@ constexpr TypeBuilderFunc getModel<float &>() {
template <>
constexpr TypeBuilderFunc getModel<bool>() {
return [](mlir::MLIRContext *context) -> mlir::Type {
- return mlir::IntegerType::get(1, context);
+ return mlir::IntegerType::get(context, 1);
};
}
template <>
@@ -190,7 +190,7 @@ struct RuntimeTableKey<RT(ATs...)> {
llvm::SmallVector<mlir::Type, sizeof...(ATs)> argTys;
for (auto f : args)
argTys.push_back(f(ctxt));
- return mlir::FunctionType::get(argTys, {retTy}, ctxt);
+ return mlir::FunctionType::get(ctxt, argTys, {retTy});
};
}
};
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 4a6c8d5..e8d8d6c 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -151,7 +151,7 @@ mlir::Type fir::BoxDimsOp::getTupleType() {
// note: triple, but 4 is nearest power of 2
llvm::SmallVector<mlir::Type, 4> triple{
getResult(0).getType(), getResult(1).getType(), getResult(2).getType()};
- return mlir::TupleType::get(triple, getContext());
+ return mlir::TupleType::get(getContext(), triple);
}
//===----------------------------------------------------------------------===//
@@ -171,7 +171,7 @@ static void printCallOp(mlir::OpAsmPrinter &p, fir::CallOp &op) {
auto resultTypes{op.getResultTypes()};
llvm::SmallVector<Type, 8> argTypes(
llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
- p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
+ p << " : " << FunctionType::get(op.getContext(), argTypes, resultTypes);
}
static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser,
@@ -1565,4 +1565,3 @@ fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module,
#define GET_OP_CLASSES
#include "flang/Optimizer/Dialect/FIROps.cpp.inc"
-
diff --git a/mlir/include/mlir/Dialect/AVX512/AVX512.td b/mlir/include/mlir/Dialect/AVX512/AVX512.td
index 56c4547..ecfff4b 100644
--- a/mlir/include/mlir/Dialect/AVX512/AVX512.td
+++ b/mlir/include/mlir/Dialect/AVX512/AVX512.td
@@ -35,8 +35,8 @@ def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [NoSideEffect,
AllTypesMatch<["src", "a", "dst"]>,
TypesMatchWith<"imm has the same number of bits as elements in dst",
"dst", "imm",
- "IntegerType::get(($_self.cast<VectorType>().getShape()[0]),"
- " $_self.getContext())">]> {
+ "IntegerType::get($_self.getContext(), "
+ "($_self.cast<VectorType>().getShape()[0]))">]> {
let summary = "Masked roundscale op";
let description = [{
The mask.rndscale op is an AVX512 specific op that can lower to the proper
@@ -67,8 +67,8 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [NoSideEffect,
AllTypesMatch<["src", "a", "b", "dst"]>,
TypesMatchWith<"k has the same number of bits as elements in dst",
"dst", "k",
- "IntegerType::get(($_self.cast<VectorType>().getShape()[0]),"
- " $_self.getContext())">]> {
+ "IntegerType::get($_self.getContext(), "
+ "($_self.cast<VectorType>().getShape()[0]))">]> {
let summary = "ScaleF op";
let description = [{
The `mask.scalef` op is an AVX512 specific op that can lower to the proper
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
index 380eaed..74ca666 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -911,7 +911,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
auto attr = (*this)->getAttr("operand_segment_sizes")
.cast<DenseIntElementsAttr>();
unsigned i = 0;
- auto newAttr = attr.mapValues(IntegerType::get(32, getContext()),
+ auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32),
[&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
getOperation()->setAttr("operand_segment_sizes", newAttr);
}
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 10e78e5..3bfb3ce4 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -63,7 +63,7 @@ public:
/// Get or create a ComplexType with the provided element type. This emits
/// and error at the specified location and returns null if the element type
/// isn't supported.
- static ComplexType getChecked(Type elementType, Location location);
+ static ComplexType getChecked(Location location, Type elementType);
/// Verify the construction of an integer type.
static LogicalResult verifyConstructionInvariants(Location loc,
@@ -93,27 +93,27 @@ public:
/// The created IntegerType is signless (i.e., no signedness semantics).
/// Assume the width is within the allowed range and assert on failures. Use
/// getChecked to handle failures gracefully.
- static IntegerType get(unsigned width, MLIRContext *context);
+ static IntegerType get(MLIRContext *context, unsigned width);
/// Get or create a new IntegerType of the given width within the context.
/// The created IntegerType has signedness semantics as indicated via
/// `signedness`. Assume the width is within the allowed range and assert on
/// failures. Use getChecked to handle failures gracefully.
- static IntegerType get(unsigned width, SignednessSemantics signedness,
- MLIRContext *context);
+ static IntegerType get(MLIRContext *context, unsigned width,
+ SignednessSemantics signedness);
/// Get or create a new IntegerType of the given width within the context,
/// defined at the given, potentially unknown, location. The created
/// IntegerType is signless (i.e., no signedness semantics). If the width is
/// outside the allowed range, emit errors and return a null type.
- static IntegerType getChecked(unsigned width, Location location);
+ static IntegerType getChecked(Location location, unsigned width);
/// Get or create a new IntegerType of the given width within the context,
/// defined at the given, potentially unknown, location. The created
/// IntegerType has signedness semantics as indicated via `signedness`. If the
/// width is outside the allowed range, emit errors and return a null type.
- static IntegerType getChecked(unsigned width, SignednessSemantics signedness,
- Location location);
+ static IntegerType getChecked(Location location, unsigned width,
+ SignednessSemantics signedness);
/// Verify the construction of an integer type.
static LogicalResult
@@ -180,8 +180,8 @@ class FunctionType
public:
using Base::Base;
- static FunctionType get(TypeRange inputs, TypeRange results,
- MLIRContext *context);
+ static FunctionType get(MLIRContext *context, TypeRange inputs,
+ TypeRange results);
/// Input types.
unsigned getNumInputs() const;
@@ -211,14 +211,14 @@ public:
using Base::Base;
/// Get or create a new OpaqueType with the provided dialect and string data.
- static OpaqueType get(Identifier dialect, StringRef typeData,
- MLIRContext *context);
+ static OpaqueType get(MLIRContext *context, Identifier dialect,
+ StringRef typeData);
/// Get or create a new OpaqueType with the provided dialect and string data.
/// If the given identifier is not a valid namespace for a dialect, then a
/// null type is returned.
- static OpaqueType getChecked(Identifier dialect, StringRef typeData,
- MLIRContext *context, Location location);
+ static OpaqueType getChecked(Location location, Identifier dialect,
+ StringRef typeData);
/// Returns the dialect namespace of the opaque type.
Identifier getDialectNamespace() const;
@@ -335,8 +335,8 @@ public:
/// declared at the given, potentially unknown, location. If the VectorType
/// defined by the arguments would be ill-formed, emit errors and return
/// nullptr-wrapping type.
- static VectorType getChecked(ArrayRef<int64_t> shape, Type elementType,
- Location location);
+ static VectorType getChecked(Location location, ArrayRef<int64_t> shape,
+ Type elementType);
/// Verify the construction of a vector type.
static LogicalResult verifyConstructionInvariants(Location loc,
@@ -394,8 +394,8 @@ public:
/// type declared at the given, potentially unknown, location. If the
/// RankedTensorType defined by the arguments would be ill-formed, emit errors
/// and return a nullptr-wrapping type.
- static RankedTensorType getChecked(ArrayRef<int64_t> shape, Type elementType,
- Location location);
+ static RankedTensorType getChecked(Location location, ArrayRef<int64_t> shape,
+ Type elementType);
/// Verify the construction of a ranked tensor type.
static LogicalResult verifyConstructionInvariants(Location loc,
@@ -424,7 +424,7 @@ public:
/// type declared at the given, potentially unknown, location. If the
/// UnrankedTensorType defined by the arguments would be ill-formed, emit
/// errors and return a nullptr-wrapping type.
- static UnrankedTensorType getChecked(Type elementType, Location location);
+ static UnrankedTensorType getChecked(Location location, Type elementType);
/// Verify the construction of a unranked tensor type.
static LogicalResult verifyConstructionInvariants(Location loc,
@@ -527,9 +527,10 @@ public:
/// UnknownLoc. If the MemRefType defined by the arguments would be
/// ill-formed, emits errors (to the handler registered with the context or to
/// the error stream) and returns nullptr.
- static MemRefType getChecked(ArrayRef<int64_t> shape, Type elementType,
+ static MemRefType getChecked(Location location, ArrayRef<int64_t> shape,
+ Type elementType,
ArrayRef<AffineMap> affineMapComposition,
- unsigned memorySpace, Location location);
+ unsigned memorySpace);
ArrayRef<int64_t> getShape() const;
@@ -573,8 +574,8 @@ public:
/// type and memory space declared at the given, potentially unknown,
/// location. If the UnrankedMemRefType defined by the arguments would be
/// ill-formed, emit errors and return a nullptr-wrapping type.
- static UnrankedMemRefType getChecked(Type elementType, unsigned memorySpace,
- Location location);
+ static UnrankedMemRefType getChecked(Location location, Type elementType,
+ unsigned memorySpace);
/// Verify the construction of a unranked memref type.
static LogicalResult verifyConstructionInvariants(Location loc,
@@ -600,7 +601,7 @@ public:
/// Get or create a new TupleType with the provided element types. Assumes the
/// arguments define a well-formed type.
- static TupleType get(TypeRange elementTypes, MLIRContext *context);
+ static TupleType get(MLIRContext *context, TypeRange elementTypes);
/// Get or create an empty tuple type.
static TupleType get(MLIRContext *context);
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index aa5ef28..552ac75 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -475,8 +475,9 @@ def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
class OpaqueType<string dialect, string name, string description>
: Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
description>,
- BuildableType<"::mlir::OpaqueType::get($_builder.getIdentifier(\""
- # dialect # "\"), \"" # name # "\", $_builder.getContext())">;
+ BuildableType<"::mlir::OpaqueType::get($_builder.getContext(), "
+ "$_builder.getIdentifier(\"" # dialect # "\"), \""
+ # name # "\")">;
// Function Type
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 5d738b4..f0c2901 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -26,15 +26,15 @@ bool mlirTypeIsAInteger(MlirType type) {
}
MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
- return wrap(IntegerType::get(bitwidth, unwrap(ctx)));
+ return wrap(IntegerType::get(unwrap(ctx), bitwidth));
}
MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
- return wrap(IntegerType::get(bitwidth, IntegerType::Signed, unwrap(ctx)));
+ return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Signed));
}
MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
- return wrap(IntegerType::get(bitwidth, IntegerType::Unsigned, unwrap(ctx)));
+ return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Unsigned));
}
unsigned mlirIntegerTypeGetWidth(MlirType type) {
@@ -172,8 +172,8 @@ MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
MlirType mlirVectorTypeGetChecked(intptr_t rank, const int64_t *shape,
MlirType elementType, MlirLocation loc) {
return wrap(VectorType::getChecked(
- llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
- unwrap(loc)));
+ unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
+ unwrap(elementType)));
}
//===----------------------------------------------------------------------===//
@@ -201,8 +201,8 @@ MlirType mlirRankedTensorTypeGetChecked(intptr_t rank, const int64_t *shape,
MlirType elementType,
MlirLocation loc) {
return wrap(RankedTensorType::getChecked(
- llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
- unwrap(loc)));
+ unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
+ unwrap(elementType)));
}
MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
@@ -211,7 +211,7 @@ MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
MlirType mlirUnrankedTensorTypeGetChecked(MlirType elementType,
MlirLocation loc) {
- return wrap(UnrankedTensorType::getChecked(unwrap(elementType), unwrap(loc)));
+ return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType)));
}
//===----------------------------------------------------------------------===//
@@ -244,8 +244,8 @@ MlirType mlirMemRefTypeContiguousGetChecked(MlirType elementType, intptr_t rank,
unsigned memorySpace,
MlirLocation loc) {
return wrap(MemRefType::getChecked(
- llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
- llvm::None, memorySpace, unwrap(loc)));
+ unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
+ unwrap(elementType), llvm::None, memorySpace));
}
intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) {
@@ -272,8 +272,8 @@ MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) {
MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType,
unsigned memorySpace,
MlirLocation loc) {
- return wrap(UnrankedMemRefType::getChecked(unwrap(elementType), memorySpace,
- unwrap(loc)));
+ return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType),
+ memorySpace));
}
unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) {
@@ -290,7 +290,7 @@ MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
MlirType const *elements) {
SmallVector<Type, 4> types;
ArrayRef<Type> typeRef = unwrapList(numElements, elements, types);
- return wrap(TupleType::get(typeRef, unwrap(ctx)));
+ return wrap(TupleType::get(unwrap(ctx), typeRef));
}
intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
@@ -316,7 +316,7 @@ MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
SmallVector<Type, 4> resultsList;
(void)unwrapList(numInputs, inputs, inputsList);
(void)unwrapList(numResults, results, resultsList);
- return wrap(FunctionType::get(inputsList, resultsList, unwrap(ctx)));
+ return wrap(FunctionType::get(unwrap(ctx), inputsList, resultsList));
}
intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 75c23d3..273754f 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -53,52 +53,52 @@ namespace {
struct AsyncAPI {
static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
auto ref = LLVM::LLVMType::getInt8PtrTy(ctx);
- auto count = IntegerType::get(32, ctx);
- return FunctionType::get({ref, count}, {}, ctx);
+ auto count = IntegerType::get(ctx, 32);
+ return FunctionType::get(ctx, {ref, count}, {});
}
static FunctionType createTokenFunctionType(MLIRContext *ctx) {
- return FunctionType::get({}, {TokenType::get(ctx)}, ctx);
+ return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
}
static FunctionType createGroupFunctionType(MLIRContext *ctx) {
- return FunctionType::get({}, {GroupType::get(ctx)}, ctx);
+ return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
}
static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
- return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
+ return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
}
static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
- return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
+ return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
}
static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
- return FunctionType::get({GroupType::get(ctx)}, {}, ctx);
+ return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
}
static FunctionType executeFunctionType(MLIRContext *ctx) {
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
auto resume = resumeFunctionType(ctx).getPointerTo();
- return FunctionType::get({hdl, resume}, {}, ctx);
+ return FunctionType::get(ctx, {hdl, resume}, {});
}
static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
- auto i64 = IntegerType::get(64, ctx);
- return FunctionType::get({TokenType::get(ctx), GroupType::get(ctx)}, {i64},
- ctx);
+ auto i64 = IntegerType::get(ctx, 64);
+ return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
+ {i64});
}
static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
auto resume = resumeFunctionType(ctx).getPointerTo();
- return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx);
+ return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
}
static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
auto resume = resumeFunctionType(ctx).getPointerTo();
- return FunctionType::get({GroupType::get(ctx), hdl, resume}, {}, ctx);
+ return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
}
// Auxiliary coroutine resume intrinsic wrapper.
@@ -690,7 +690,7 @@ public:
if (!addToGroup.operand().getType().isa<TokenType>())
return failure();
- auto i64 = IntegerType::get(64, op->getContext());
+ auto i64 = IntegerType::get(op->getContext(), 64);
rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands);
return success();
}
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
index 9582756..887d3e7 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
@@ -122,7 +122,7 @@ LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc(
}
// Declare vulkan launch function.
- auto funcType = FunctionType::get(vulkanLaunchTypes, {}, loc->getContext());
+ auto funcType = builder.getFunctionType(vulkanLaunchTypes, {});
builder.create<FuncOp>(loc, kVulkanLaunch, funcType).setPrivate();
return success();
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 47e8f27..5546c82 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -84,7 +84,7 @@ static LLVMType getPtrToElementType(T containerType,
/// };
static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
auto *context = t.getContext();
- auto int64Ty = converter.convertType(IntegerType::get(64, context))
+ auto int64Ty = converter.convertType(IntegerType::get(context, 64))
.cast<LLVM::LLVMType>();
return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
}
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index f884a15..8b53ecb 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -65,7 +65,7 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
assert(op->getNumResults() == 0 &&
"Library call for linalg operation can be generated only for ops that "
"have void return types");
- auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext());
+ auto libFnType = rewriter.getFunctionType(inputTypes, {});
OpBuilder::InsertionGuard guard(rewriter);
// Insert before module terminator.
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index 22a4310..cacb478 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -407,8 +407,7 @@ public:
// cover all possible corner cases.
if (isSignedIntegerOrVector(srcType) ||
isUnsignedIntegerOrVector(srcType)) {
- auto *context = rewriter.getContext();
- auto signlessType = IntegerType::get(getBitWidth(srcType), context);
+ auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
if (srcType.isa<VectorType>()) {
auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index e9d05298..c05e9da 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -584,7 +584,7 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
std::swap(ivsStorage.back(), ivsStorage[coalescedIdx]);
ArrayRef<Value> ivs(ivsStorage);
- Value pos = std_index_cast(IntegerType::get(32, ctx), ivs.back());
+ Value pos = std_index_cast(IntegerType::get(ctx, 32), ivs.back());
Value inVector = local(ivs.drop_back());
auto loadValue = [&](ArrayRef<Value> indices) {
Value vector = vector_insert_element(remote(indices), inVector, pos);
@@ -671,7 +671,7 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
ArrayRef<Value> ivs(ivsStorage);
Value pos =
- std_index_cast(IntegerType::get(32, op->getContext()), ivs.back());
+ std_index_cast(IntegerType::get(op->getContext(), 32), ivs.back());
auto storeValue = [&](ArrayRef<Value> indices) {
Value scalar = vector_extract_element(local(ivs.drop_back()), pos);
remote(indices) = scalar;
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 36b3393..2bdf31f 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -152,7 +152,7 @@ void ExecuteOp::build(OpBuilder &builder, OperationState &result,
int32_t numDependencies = dependencies.size();
int32_t numOperands = operands.size();
auto operandSegmentSizes = DenseIntElementsAttr::get(
- VectorType::get({2}, IntegerType::get(32, result.getContext())),
+ VectorType::get({2}, builder.getIntegerType(32)),
{numDependencies, numOperands});
result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp
index ea1da59..17f27cb 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp
@@ -118,7 +118,7 @@ LogicalResult AsyncRefCountingPass::addAutomaticRefCounting(Value value) {
builder.setInsertionPointToStart(value.getParentBlock());
Location loc = value.getLoc();
- auto i32 = IntegerType::get(32, ctx);
+ auto i32 = IntegerType::get(ctx, 32);
// Drop the reference count immediately if the value has no uses.
if (value.getUses().empty()) {
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index d3ee3e2..669c755 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -31,7 +31,7 @@ struct GpuAllReduceRewriter {
: funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_),
loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
indexType(IndexType::get(reduceOp.getContext())),
- int32Type(IntegerType::get(/*width=*/32, reduceOp.getContext())) {}
+ int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}
/// Creates an all_reduce across the workgroup.
///
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 83bc316..ae9b2fe 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -155,7 +155,7 @@ static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp,
kernelOperandTypes.push_back(operand.getType());
}
FunctionType type =
- FunctionType::get(kernelOperandTypes, {}, launchOp.getContext());
+ FunctionType::get(launchOp.getContext(), kernelOperandTypes, {});
auto outlinedFunc = builder.create<gpu::GPUFuncOp>(loc, kernelFnName, type);
outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
builder.getUnitAttr());
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 9b2c88c..165cfe06 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -120,8 +120,8 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) {
auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
- auto funcTy = FunctionType::get({op.arraySize().getType()}, {op.getType()},
- op.getContext());
+ auto funcTy = FunctionType::get(op.getContext(), {op.arraySize().getType()},
+ {op.getType()});
p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy;
if (op.alignment().hasValue() && *op.alignment() != 0)
@@ -781,7 +781,7 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) {
// Reconstruct the function MLIR function type from operand and result types.
p << " : "
- << FunctionType::get(args.getTypes(), op.getResultTypes(), op.getContext());
+ << FunctionType::get(op.getContext(), args.getTypes(), op.getResultTypes());
}
// <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 2410df8..073673b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -76,25 +76,25 @@ static Value allocBuffer(const LinalgPromotionOptions &options,
IntegerAttr alignment_attr;
if (alignment.hasValue())
alignment_attr =
- IntegerAttr::get(IntegerType::get(64, ctx), alignment.getValue());
+ IntegerAttr::get(IntegerType::get(ctx, 64), alignment.getValue());
if (!dynamicBuffers)
if (auto cst = size.getDefiningOp<ConstantIndexOp>())
return options.useAlloca
? std_alloca(MemRefType::get(width * cst.getValue(),
- IntegerType::get(8, ctx)),
+ IntegerType::get(ctx, 8)),
ValueRange{}, alignment_attr)
.value
: std_alloc(MemRefType::get(width * cst.getValue(),
- IntegerType::get(8, ctx)),
+ IntegerType::get(ctx, 8)),
ValueRange{}, alignment_attr)
.value;
Value mul =
folded_std_muli(folder, folded_std_constant_index(folder, width), size);
return options.useAlloca
- ? std_alloca(MemRefType::get(-1, IntegerType::get(8, ctx)), mul,
+ ? std_alloca(MemRefType::get(-1, IntegerType::get(ctx, 8)), mul,
alignment_attr)
.value
- : std_alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul,
+ : std_alloc(MemRefType::get(-1, IntegerType::get(ctx, 8)), mul,
alignment_attr)
.value;
}
diff --git a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
index 37b72af..dd1d058 100644
--- a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp
@@ -18,7 +18,7 @@ static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
int64_t &qmax) {
// Hard-coded type mapping from TFLite.
if (numBits <= 8) {
- storageType = IntegerType::get(8, ctx);
+ storageType = IntegerType::get(ctx, 8);
if (isSigned) {
qmin = -128;
qmax = 127;
@@ -27,7 +27,7 @@ static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
qmax = 255;
}
} else if (numBits <= 16) {
- storageType = IntegerType::get(16, ctx);
+ storageType = IntegerType::get(ctx, 16);
if (isSigned) {
qmin = -32768;
qmax = 32767;
@@ -36,7 +36,7 @@ static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
qmax = 65535;
}
} else if (numBits <= 32) {
- storageType = IntegerType::get(32, ctx);
+ storageType = IntegerType::get(ctx, 32);
if (isSigned) {
qmin = std::numeric_limits<int32_t>::min();
qmax = std::numeric_limits<int32_t>::max();
diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
index aadbf13..6d11bc0 100644
--- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
@@ -79,7 +79,7 @@ UniformQuantizedPerAxisValueConverter::convert(DenseFPElementsAttr attr) {
int64_t chunkSize =
std::accumulate(std::next(shape.begin(), quantizationDim + 1),
shape.end(), 1, std::multiplies<int64_t>());
- Type newElementType = IntegerType::get(storageBitWidth, attr.getContext());
+ Type newElementType = IntegerType::get(attr.getContext(), storageBitWidth);
return attr.mapValues(newElementType, [&](const APFloat &old) {
int chunkIndex = (flattenIndex++) / chunkSize;
return converters[chunkIndex % dimSize].quantizeFloatToInt(old);
diff --git a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
index c30c823..0d3d2ff 100644
--- a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
@@ -96,7 +96,7 @@ void mlir::outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
ValueRange values(captures.getArrayRef());
FunctionType type =
- FunctionType::get(values.getTypes(), ifOp.getResultTypes(), ctx);
+ FunctionType::get(ctx, values.getTypes(), ifOp.getResultTypes());
auto outlinedFunc = b.create<FuncOp>(loc, funcName, type);
b.setInsertionPointToStart(outlinedFunc.addEntryBlock());
BlockAndValueMapping bvm;
diff --git a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
index b6c5f82..644d2bf 100644
--- a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
@@ -123,7 +123,7 @@ spirv::getEntryPointABIAttr(ArrayRef<int32_t> localSize, MLIRContext *context) {
assert(localSize.size() == 3);
return spirv::EntryPointABIAttr::get(
DenseElementsAttr::get<int32_t>(
- VectorType::get(3, IntegerType::get(32, context)), localSize)
+ VectorType::get(3, IntegerType::get(context, 32)), localSize)
.cast<DenseIntElementsAttr>(),
context);
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index bf3a509..b310d5d 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -93,7 +93,7 @@ Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
// instructions. The Vulkan spec requires the builtins like
// GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be
// SExtended to 64-bit for index computations.
- return IntegerType::get(32, context);
+ return IntegerType::get(context, 32);
}
/// Mapping between SPIR-V storage classes to memref memory spaces.
@@ -260,8 +260,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
auto intType = type.cast<IntegerType>();
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
- return IntegerType::get(/*width=*/32, intType.getSignedness(),
- targetEnv.getContext());
+ return IntegerType::get(targetEnv.getContext(), /*width=*/32,
+ intType.getSignedness());
}
/// Converts a vector `type` to a suitable type under the given `targetEnv`.
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 543c61a..b19264e 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -714,7 +714,7 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
}
FunctionType CallOp::getCalleeType() {
- return FunctionType::get(getOperandTypes(), getResultTypes(), getContext());
+ return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
}
//===----------------------------------------------------------------------===//
@@ -753,7 +753,7 @@ void CallIndirectOp::getCanonicalizationPatterns(
// Return the type of the same shape (scalar, vector or tensor) containing i1.
static Type getI1SameShape(Type type) {
- auto i1Type = IntegerType::get(1, type.getContext());
+ auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type);
if (type.isa<UnrankedTensorType>())
@@ -914,7 +914,7 @@ OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
return {};
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
- return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
+ return IntegerAttr::get(IntegerType::get(getContext(), 1), APInt(1, val));
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index c74ef88..5c1f377 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1426,7 +1426,7 @@ static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
MLIRContext *context) {
auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
- return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v));
+ return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
});
return ArrayAttr::get(llvm::to_vector<8>(attrs), context);
}
@@ -2767,7 +2767,7 @@ static ParseResult parseTupleOp(OpAsmParser &parser, OperationState &result) {
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonTypeList(types) ||
parser.resolveOperands(operandInfos, types, loc, result.operands) ||
- parser.addTypeToList(TupleType::get(types, ctx), result.types));
+ parser.addTypeToList(TupleType::get(ctx, types), result.types));
}
static void print(OpAsmPrinter &p, TupleOp op) {
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index d3aa77a..6644269 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -215,7 +215,7 @@ static TupleType generateExtractSlicesOpResultType(VectorType vectorType,
// Create Vector type and add to 'vectorTypes[i]'.
vectorTypes[i] = VectorType::get(sliceSizes, vectorType.getElementType());
}
- return TupleType::get(vectorTypes, builder.getContext());
+ return builder.getTupleType(vectorTypes);
}
// UnrolledVectorState aggregates per-operand/result vector state required for
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 64eb37c..3d06551 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -52,27 +52,27 @@ FloatType Builder::getF64Type() { return FloatType::getF64(context); }
IndexType Builder::getIndexType() { return IndexType::get(context); }
-IntegerType Builder::getI1Type() { return IntegerType::get(1, context); }
+IntegerType Builder::getI1Type() { return IntegerType::get(context, 1); }
-IntegerType Builder::getI32Type() { return IntegerType::get(32, context); }
+IntegerType Builder::getI32Type() { return IntegerType::get(context, 32); }
-IntegerType Builder::getI64Type() { return IntegerType::get(64, context); }
+IntegerType Builder::getI64Type() { return IntegerType::get(context, 64); }
IntegerType Builder::getIntegerType(unsigned width) {
- return IntegerType::get(width, context);
+ return IntegerType::get(context, width);
}
IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
return IntegerType::get(
- width, isSigned ? IntegerType::Signed : IntegerType::Unsigned, context);
+ context, width, isSigned ? IntegerType::Signed : IntegerType::Unsigned);
}
FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
- return FunctionType::get(inputs, results, context);
+ return FunctionType::get(context, inputs, results);
}
TupleType Builder::getTupleType(TypeRange elementTypes) {
- return TupleType::get(elementTypes, context);
+ return TupleType::get(context, elementTypes);
}
NoneType Builder::getNoneType() { return NoneType::get(context); }
diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index b125ae0..508f0cc 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -179,7 +179,7 @@ FuncOp FuncOp::clone(BlockAndValueMapping &mapper) {
for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
if (!mapper.contains(getArgument(i)))
inputTypes.push_back(newType.getInput(i));
- newType = FunctionType::get(inputTypes, newType.getResults(), getContext());
+ newType = FunctionType::get(getContext(), inputTypes, newType.getResults());
}
// Create the new function.
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 68cf491..b4fc53d 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -35,7 +35,7 @@ ComplexType ComplexType::get(Type elementType) {
return Base::get(elementType.getContext(), elementType);
}
-ComplexType ComplexType::getChecked(Type elementType, Location location) {
+ComplexType ComplexType::getChecked(Location location, Type elementType) {
return Base::getChecked(location, elementType);
}
@@ -76,7 +76,7 @@ IntegerType::SignednessSemantics IntegerType::getSignedness() const {
IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
if (!scale)
return IntegerType();
- return IntegerType::get(scale * getWidth(), getSignedness(), getContext());
+ return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
}
//===----------------------------------------------------------------------===//
@@ -126,8 +126,8 @@ FloatType FloatType::scaleElementBitwidth(unsigned scale) {
// FunctionType
//===----------------------------------------------------------------------===//
-FunctionType FunctionType::get(TypeRange inputs, TypeRange results,
- MLIRContext *context) {
+FunctionType FunctionType::get(MLIRContext *context, TypeRange inputs,
+ TypeRange results) {
return Base::get(context, inputs, results);
}
@@ -182,20 +182,20 @@ FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
newResultTypes = newResultTypesBuffer;
}
- return get(newInputTypes, newResultTypes, getContext());
+ return get(getContext(), newInputTypes, newResultTypes);
}
//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//
-OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData,
- MLIRContext *context) {
+OpaqueType OpaqueType::get(MLIRContext *context, Identifier dialect,
+ StringRef typeData) {
return Base::get(context, dialect, typeData);
}
-OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData,
- MLIRContext *context, Location location) {
+OpaqueType OpaqueType::getChecked(Location location, Identifier dialect,
+ StringRef typeData) {
return Base::getChecked(location, dialect, typeData);
}
@@ -313,8 +313,8 @@ VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
return Base::get(elementType.getContext(), shape, elementType);
}
-VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
- Location location) {
+VectorType VectorType::getChecked(Location location, ArrayRef<int64_t> shape,
+ Type elementType) {
return Base::getChecked(location, shape, elementType);
}
@@ -379,9 +379,9 @@ RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
return Base::get(elementType.getContext(), shape, elementType);
}
-RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
- Type elementType,
- Location location) {
+RankedTensorType RankedTensorType::getChecked(Location location,
+ ArrayRef<int64_t> shape,
+ Type elementType) {
return Base::getChecked(location, shape, elementType);
}
@@ -406,8 +406,8 @@ UnrankedTensorType UnrankedTensorType::get(Type elementType) {
return Base::get(elementType.getContext(), elementType);
}
-UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
- Location location) {
+UnrankedTensorType UnrankedTensorType::getChecked(Location location,
+ Type elementType) {
return Base::getChecked(location, elementType);
}
@@ -448,9 +448,10 @@ MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
/// UnknownLoc. If the MemRefType defined by the arguments would be
/// ill-formed, emits errors (to the handler registered with the context or to
/// the error stream) and returns nullptr.
-MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType,
+MemRefType MemRefType::getChecked(Location location, ArrayRef<int64_t> shape,
+ Type elementType,
ArrayRef<AffineMap> affineMapComposition,
- unsigned memorySpace, Location location) {
+ unsigned memorySpace) {
return getImpl(shape, elementType, affineMapComposition, memorySpace,
location);
}
@@ -524,9 +525,9 @@ UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
return Base::get(elementType.getContext(), elementType, memorySpace);
}
-UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
- unsigned memorySpace,
- Location location) {
+UnrankedMemRefType UnrankedMemRefType::getChecked(Location location,
+ Type elementType,
+ unsigned memorySpace) {
return Base::getChecked(location, elementType, memorySpace);
}
@@ -694,12 +695,12 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
/// Get or create a new TupleType with the provided element types. Assumes the
/// arguments define a well-formed type.
-TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) {
+TupleType TupleType::get(MLIRContext *context, TypeRange elementTypes) {
return Base::get(context, elementTypes);
}
/// Get or create an empty tuple type.
-TupleType TupleType::get(MLIRContext *context) { return get({}, context); }
+TupleType TupleType::get(MLIRContext *context) { return get(context, {}); }
/// Return the elements types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index ff2b7a0..beabd48 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -82,7 +82,7 @@ Type Dialect::parseType(DialectAsmParser &parser) const {
// If this dialect allows unknown types, then represent this with OpaqueType.
if (allowsUnknownTypes()) {
auto ns = Identifier::get(getNamespace(), getContext());
- return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext());
+ return OpaqueType::get(getContext(), ns, parser.getFullSymbolSpec());
}
parser.emitError(parser.getNameLoc())
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index a714d83..0837a65 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -772,25 +772,23 @@ getCachedIntegerType(unsigned width,
}
}
-IntegerType IntegerType::get(unsigned width, MLIRContext *context) {
- return get(width, IntegerType::Signless, context);
+IntegerType IntegerType::get(MLIRContext *context, unsigned width) {
+ return get(context, width, IntegerType::Signless);
}
-IntegerType IntegerType::get(unsigned width,
- IntegerType::SignednessSemantics signedness,
- MLIRContext *context) {
+IntegerType IntegerType::get(MLIRContext *context, unsigned width,
+ IntegerType::SignednessSemantics signedness) {
if (auto cached = getCachedIntegerType(width, signedness, context))
return cached;
return Base::get(context, width, signedness);
}
-IntegerType IntegerType::getChecked(unsigned width, Location location) {
- return getChecked(width, IntegerType::Signless, location);
+IntegerType IntegerType::getChecked(Location location, unsigned width) {
+ return getChecked(location, width, IntegerType::Signless);
}
-IntegerType IntegerType::getChecked(unsigned width,
- SignednessSemantics signedness,
- Location location) {
+IntegerType IntegerType::getChecked(Location location, unsigned width,
+ SignednessSemantics signedness) {
if (auto cached =
getCachedIntegerType(width, signedness, location->getContext()))
return cached;
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 1795dd9..1b8d287 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -178,7 +178,7 @@ Operation::Operation(Location location, OperationName name,
if (hasSingleResult)
resultType = resultTypes.front();
else
- resultType = TupleType::get(resultTypes, location->getContext());
+ resultType = TupleType::get(location->getContext(), resultTypes);
}
}
diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
index 7ec497c..fd7e5b5 100644
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -63,7 +63,7 @@ void Value::setType(Type newType) {
return;
auto newTypes = llvm::to_vector<4>(curTypes);
newTypes[resultNo] = newType;
- owner->resultType = TupleType::get(newTypes, newType.getContext());
+ owner->resultType = TupleType::get(newType.getContext(), newTypes);
}
/// If this value is the result of an Operation, return the operation that
diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
index 11e7e23..7cf0596 100644
--- a/mlir/lib/Parser/DialectSymbolParser.cpp
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -563,8 +563,8 @@ Type Parser::parseExtendedType() {
// Otherwise, form a new opaque type.
return OpaqueType::getChecked(
- Identifier::get(dialectName, state.context), symbolData,
- state.context, getEncodedSourceLocation(loc));
+ getEncodedSourceLocation(loc),
+ Identifier::get(dialectName, state.context), symbolData);
});
}
diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
index 1e7cffb..ab7f85a 100644
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -338,7 +338,7 @@ Type Parser::parseNonFunctionType() {
signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned;
consumeToken(Token::inttype);
- return IntegerType::get(width.getValue(), signSemantics, getContext());
+ return IntegerType::get(getContext(), width.getValue(), signSemantics);
}
// float-type
@@ -432,7 +432,7 @@ Type Parser::parseTupleType() {
parseToken(Token::greater, "expected '>' in tuple type"))
return nullptr;
- return TupleType::get(types, getContext());
+ return TupleType::get(getContext(), types);
}
/// Parse a vector type.
diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index e0a8420..7f89a41 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -236,7 +236,7 @@ Type Importer::getStdTypeForAttr(LLVMType type) {
Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
if (auto *ci = dyn_cast<llvm::ConstantInt>(value))
return b.getIntegerAttr(
- IntegerType::get(ci->getType()->getBitWidth(), context),
+ IntegerType::get(context, ci->getType()->getBitWidth()),
ci->getValue());
if (auto *c = dyn_cast<llvm::ConstantDataArray>(value))
if (c->isString())
diff --git a/mlir/lib/Target/SPIRV/Deserialization.cpp b/mlir/lib/Target/SPIRV/Deserialization.cpp
index bf2c143..4b04e9c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization.cpp
@@ -1182,7 +1182,7 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
// signless semantics for such cases.
auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
: IntegerType::SignednessSemantics::Signless;
- typeMap[operands[0]] = IntegerType::get(operands[1], sign, context);
+ typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
} break;
case spirv::Opcode::OpTypeFloat: {
if (operands.size() != 2)
@@ -1345,7 +1345,7 @@ LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
if (!isVoidType(returnType)) {
returnTypes = llvm::makeArrayRef(returnType);
}
- typeMap[operands[0]] = FunctionType::get(argTypes, returnTypes, context);
+ typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
return success();
}
diff --git a/mlir/lib/Target/SPIRV/Serialization.cpp b/mlir/lib/Target/SPIRV/Serialization.cpp
index cb41c7b..db00f5e 100644
--- a/mlir/lib/Target/SPIRV/Serialization.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization.cpp
@@ -1267,7 +1267,7 @@ LogicalResult Serializer::prepareBasicType(
}
typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
auto getConstantOp = [&](uint32_t id) {
- auto attr = IntegerAttr::get(IntegerType::get(32, type.getContext()), id);
+ auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
return prepareConstantInt(loc, attr);
};
operands.push_back(elementTypeID);
diff --git a/mlir/lib/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Transforms/BufferResultsToOutParams.cpp
index e1cb148..87f1a2e 100644
--- a/mlir/lib/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Transforms/BufferResultsToOutParams.cpp
@@ -35,8 +35,8 @@ static void updateFuncOp(FuncOp func,
// Add the new arguments to the function type.
auto newArgTypes = llvm::to_vector<6>(
llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
- auto newFunctionType = FunctionType::get(
- newArgTypes, functionType.getResults(), func.getContext());
+ auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
+ functionType.getResults());
func.setType(newFunctionType);
// Transfer the result attributes to arg attributes.
diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
index d7fa212..e9bc92f 100644
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -230,9 +230,8 @@ void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
// We create a new function type and modify the function signature with this
// new type.
- newFuncType = FunctionType::get(/*inputs=*/argTypes,
- /*results=*/resultTypes,
- /*context=*/&getContext());
+ newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes,
+ /*results=*/resultTypes);
}
// Since we update the function signature, it might affect the result types at
@@ -463,9 +462,9 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
continue;
}
- FunctionType newFuncType = FunctionType::get(/*inputs=*/inputTypes,
- /*results=*/resultTypes,
- /*context=*/&getContext());
+ FunctionType newFuncType =
+ FunctionType::get(&getContext(), /*inputs=*/inputTypes,
+ /*results=*/resultTypes);
// Setting the new function signature for this external function.
funcOp.setType(newFuncType);
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 0a1a6b7..7c71164 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2522,8 +2522,8 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
// Update the function signature in-place.
rewriter.updateRootInPlace(funcOp, [&] {
- funcOp.setType(FunctionType::get(result.getConvertedTypes(), newResults,
- funcOp.getContext()));
+ funcOp.setType(FunctionType::get(funcOp.getContext(),
+ result.getConvertedTypes(), newResults));
});
return success();
}
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index 7918b7d..b713ae9 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -56,7 +56,7 @@ static FuncOp makeFunction(StringRef name, ArrayRef<Type> results = {},
ArrayRef<Type> args = {}) {
auto &ctx = globalContext();
auto function = FuncOp::create(UnknownLoc::get(&ctx), name,
- FunctionType::get(args, results, &ctx));
+ FunctionType::get(&ctx, args, results));
function.addEntryBlock();
return function;
}
@@ -277,7 +277,7 @@ TEST_FUNC(builder_blocks) {
TEST_FUNC(builder_cond_branch) {
auto f = makeFunction("builder_cond_branch", {},
- {IntegerType::get(1, &globalContext())});
+ {IntegerType::get(&globalContext(), 1)});
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
@@ -390,8 +390,8 @@ TEST_FUNC(insertion_in_block) {
TEST_FUNC(zero_and_std_sign_extendi_op_i1_to_i8) {
using namespace edsc::op;
- auto i1Type = IntegerType::get(1, &globalContext());
- auto i8Type = IntegerType::get(8, &globalContext());
+ auto i1Type = IntegerType::get(&globalContext(), 1);
+ auto i8Type = IntegerType::get(&globalContext(), 8);
auto memrefType = MemRefType::get({}, i1Type, {}, 0);
auto f = makeFunction("zero_and_std_sign_extendi_op", {},
{memrefType, memrefType});
@@ -414,7 +414,7 @@ TEST_FUNC(zero_and_std_sign_extendi_op_i1_to_i8) {
}
TEST_FUNC(operator_or) {
- auto i1Type = IntegerType::get(/*width=*/1, &globalContext());
+ auto i1Type = IntegerType::get(&globalContext(), /*width=*/1);
auto f = makeFunction("operator_or", {}, {i1Type, i1Type});
OpBuilder builder(f.getBody());
@@ -435,7 +435,7 @@ TEST_FUNC(operator_or) {
}
TEST_FUNC(operator_and) {
- auto i1Type = IntegerType::get(/*width=*/1, &globalContext());
+ auto i1Type = IntegerType::get(&globalContext(), /*width=*/1);
auto f = makeFunction("operator_and", {}, {i1Type, i1Type});
OpBuilder builder(f.getBody());
@@ -536,7 +536,7 @@ TEST_FUNC(fptrunc_f32_bf16) {
TEST_FUNC(select_op_i32) {
using namespace edsc::op;
- auto i32Type = IntegerType::get(32, &globalContext());
+ auto i32Type = IntegerType::get(&globalContext(), 32);
auto memrefType = MemRefType::get(
{ShapedType::kDynamicSize, ShapedType::kDynamicSize}, i32Type, {}, 0);
auto f = makeFunction("select_op", {}, {memrefType});
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 8aec984..bb5ceda 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -653,7 +653,7 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
}
int64_t dim =
sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
- auto type = IntegerType::get(17, context);
+ auto type = IntegerType::get(context, 17);
inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
return success();
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 7795b79..849c6c6 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -509,7 +509,7 @@ struct TestTypeConverter : public TypeConverter {
// Convert I42 to I43.
if (t.isInteger(42)) {
- results.push_back(IntegerType::get(43, t.getContext()));
+ results.push_back(IntegerType::get(t.getContext(), 43));
return success();
}
diff --git a/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp
index 26a0ae1..2dd2c34 100644
--- a/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp
@@ -69,9 +69,7 @@ struct TestDecomposeCallGraphTypes
Location loc) -> Optional<Value> {
if (inputs.size() == 1)
return llvm::None;
- TypeRange TypeRange = inputs.getTypes();
- SmallVector<Type, 2> types(TypeRange.begin(), TypeRange.end());
- TupleType tuple = TupleType::get(types, builder.getContext());
+ TupleType tuple = builder.getTupleType(inputs.getTypes());
Value value = builder.create<test::MakeTupleOp>(loc, tuple, inputs);
return value;
});
diff --git a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
index d7af225..33c6360 100644
--- a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
@@ -59,7 +59,7 @@ ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx,
} else {
tensorType = RankedTensorType::get(shape, eleType);
}
- auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(64, ctx));
+ auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(ctx, 64));
auto indices =
DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
auto valuesType = RankedTensorType::get({1}, eleType);
@@ -77,7 +77,7 @@ UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
MLIRContext ctx;
ctx.getOrLoadDialect<QuantizationDialect>();
- IntegerType convertedType = IntegerType::get(8, &ctx);
+ IntegerType convertedType = IntegerType::get(&ctx, 8);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType);
@@ -95,7 +95,7 @@ TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
MLIRContext ctx;
ctx.getOrLoadDialect<QuantizationDialect>();
- IntegerType convertedType = IntegerType::get(8, &ctx);
+ IntegerType convertedType = IntegerType::get(&ctx, 8);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType);
auto realValue = getTestElementsAttr<DenseElementsAttr, ArrayRef<Attribute>>(
@@ -120,7 +120,7 @@ TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
MLIRContext ctx;
ctx.getOrLoadDialect<QuantizationDialect>();
- IntegerType convertedType = IntegerType::get(8, &ctx);
+ IntegerType convertedType = IntegerType::get(&ctx, 8);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType);
auto realValue = getTestElementsAttr<DenseElementsAttr, Attribute>(
@@ -145,7 +145,7 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
MLIRContext ctx;
ctx.getOrLoadDialect<QuantizationDialect>();
- IntegerType convertedType = IntegerType::get(8, &ctx);
+ IntegerType convertedType = IntegerType::get(&ctx, 8);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType);
auto realValue = getTestSparseElementsAttr(&ctx, {1, 2});
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index e80e2f6..e1f603a 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -33,7 +33,7 @@ static void testSplat(Type eltType, const EltTy &splatElt) {
namespace {
TEST(DenseSplatTest, BoolSplat) {
MLIRContext context;
- IntegerType boolTy = IntegerType::get(1, &context);
+ IntegerType boolTy = IntegerType::get(&context, 1);
RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
// Check that splat is automatically detected for boolean values.
@@ -58,7 +58,7 @@ TEST(DenseSplatTest, LargeBoolSplat) {
constexpr int64_t boolCount = 56;
MLIRContext context;
- IntegerType boolTy = IntegerType::get(1, &context);
+ IntegerType boolTy = IntegerType::get(&context, 1);
RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
// Check that splat is automatically detected for boolean values.
@@ -81,7 +81,7 @@ TEST(DenseSplatTest, LargeBoolSplat) {
TEST(DenseSplatTest, BoolNonSplat) {
MLIRContext context;
- IntegerType boolTy = IntegerType::get(1, &context);
+ IntegerType boolTy = IntegerType::get(&context, 1);
RankedTensorType shape = RankedTensorType::get({6}, boolTy);
// Check that we properly handle non-splat values.
@@ -94,7 +94,7 @@ TEST(DenseSplatTest, OddIntSplat) {
// Test detecting a splat with an odd(non 8-bit) integer bitwidth.
MLIRContext context;
constexpr size_t intWidth = 19;
- IntegerType intTy = IntegerType::get(intWidth, &context);
+ IntegerType intTy = IntegerType::get(&context, intWidth);
APInt value(intWidth, 10);
testSplat(intTy, value);
@@ -102,7 +102,7 @@ TEST(DenseSplatTest, OddIntSplat) {
TEST(DenseSplatTest, Int32Splat) {
MLIRContext context;
- IntegerType intTy = IntegerType::get(32, &context);
+ IntegerType intTy = IntegerType::get(&context, 32);
int value = 64;
testSplat(intTy, value);
@@ -110,7 +110,7 @@ TEST(DenseSplatTest, Int32Splat) {
TEST(DenseSplatTest, IntAttrSplat) {
MLIRContext context;
- IntegerType intTy = IntegerType::get(85, &context);
+ IntegerType intTy = IntegerType::get(&context, 85);
Attribute value = IntegerAttr::get(intTy, 109);
testSplat(intTy, value);
@@ -151,7 +151,7 @@ TEST(DenseSplatTest, BF16Splat) {
TEST(DenseSplatTest, StringSplat) {
MLIRContext context;
Type stringType =
- OpaqueType::get(Identifier::get("test", &context), "string", &context);
+ OpaqueType::get(&context, Identifier::get("test", &context), "string");
StringRef value = "test-string";
testSplat(stringType, value);
}
@@ -159,7 +159,7 @@ TEST(DenseSplatTest, StringSplat) {
TEST(DenseSplatTest, StringAttrSplat) {
MLIRContext context;
Type stringType =
- OpaqueType::get(Identifier::get("test", &context), "string", &context);
+ OpaqueType::get(&context, Identifier::get("test", &context), "string");
Attribute stringAttr = StringAttr::get("test-string", stringType);
testSplat(stringType, stringAttr);
}
@@ -173,7 +173,7 @@ TEST(DenseComplexTest, ComplexFloatSplat) {
TEST(DenseComplexTest, ComplexIntSplat) {
MLIRContext context;
- ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
+ ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
std::complex<int64_t> value(10, 15);
testSplat(complexType, value);
}
@@ -187,7 +187,7 @@ TEST(DenseComplexTest, ComplexAPFloatSplat) {
TEST(DenseComplexTest, ComplexAPIntSplat) {
MLIRContext context;
- ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
+ ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
testSplat(complexType, value);
}
diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp
index c1c8836..0dd9ef9 100644
--- a/mlir/unittests/TableGen/StructsGenTest.cpp
+++ b/mlir/unittests/TableGen/StructsGenTest.cpp
@@ -25,7 +25,7 @@ namespace mlir {
/// Helper that returns an example test::TestStruct for testing its
/// implementation.
static test::TestStruct getTestStruct(mlir::MLIRContext *context) {
- auto integerType = mlir::IntegerType::get(32, context);
+ auto integerType = mlir::IntegerType::get(context, 32);
auto integerAttr = mlir::IntegerAttr::get(integerType, 127);
auto floatType = mlir::FloatType::getF32(context);
@@ -105,7 +105,7 @@ TEST(StructsGenTest, ClassofBadTypeFalse) {
expectedValues.begin(), expectedValues.end() - 1);
// Add a copy of the last attribute with the wrong type.
- auto i64Type = mlir::IntegerType::get(64, &context);
+ auto i64Type = mlir::IntegerType::get(&context, 64);
auto elementsType = mlir::RankedTensorType::get({3}, i64Type);
auto elementsAttr =
mlir::DenseIntElementsAttr::get(elementsType, ArrayRef<int64_t>{1, 2, 3});