aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r--mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp80
-rw-r--r--mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp10
-rw-r--r--mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp6
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp6
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h21
-rw-r--r--mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp5
-rw-r--r--mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp21
-rw-r--r--mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp4
-rw-r--r--mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp26
-rw-r--r--mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp24
10 files changed, 116 insertions, 87 deletions
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 85f0fd1d..478b6aa 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
@@ -993,28 +994,36 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
Chipset chipset) {
- auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
- auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
- auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
- auto elemSourceType = sourceVectorType.getElementType();
- auto elemBSourceType = sourceBVectorType.getElementType();
- auto elemDestType = destVectorType.getElementType();
-
- if (elemSourceType.isF16() && elemDestType.isF32())
- return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
- if (elemSourceType.isBF16() && elemDestType.isF32())
- return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
- if (elemSourceType.isF16() && elemDestType.isF16())
- return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
- if (elemSourceType.isBF16() && elemDestType.isBF16())
- return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
- if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
- return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
- if (chipset.majorVersion == 11) {
- if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
- return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+ auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
+ auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
+ auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
+ Type elemSourceType = sourceVectorType.getElementType();
+ Type elemBSourceType = sourceBVectorType.getElementType();
+ Type elemDestType = destVectorType.getElementType();
+
+ const uint32_t k = wmma.getK();
+
+ if (k == 16) {
+ if (elemSourceType.isF16() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
+ if (elemSourceType.isBF16() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
+ if (elemSourceType.isF16() && elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
+ if (elemSourceType.isBF16() && elemDestType.isBF16())
+ return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
+ if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+ return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+ if (chipset.majorVersion == 11) {
+ if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+ return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+ }
}
- if (chipset.majorVersion >= 12) {
+ if (chipset.majorVersion < 12)
+ return std::nullopt;
+
+ // gfx12+
+ if (k == 16) {
if (isa<Float8E4M3FNType>(elemSourceType) &&
isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
@@ -1027,17 +1036,18 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
if (isa<Float8E5M2Type>(elemSourceType) &&
isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
- if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
- bool isWave64 = destVectorType.getNumElements() == 4;
- // This is the ambiguous case. 8 inputs to the wave64 version means that
- // we want the 16x16x32 version, but for wave32 they mean the short form.
- bool has8Inputs = sourceVectorType.getNumElements() == 8;
- if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
- return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+ if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
- }
+
+ return std::nullopt;
}
- return std::nullopt;
+ if (k == 32) {
+ if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+ return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+ return std::nullopt;
+ }
+
+ llvm_unreachable("unhandled WMMA case");
}
namespace {
@@ -1927,16 +1937,16 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
else
llvm_unreachable("unsupported row length");
- const Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
- const Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
+ Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
+ Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
- const Value isEqual =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, vdst0, v);
+ Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
+ LLVM::ICmpPredicate::eq, vdst0, v);
// Per `permlane(16|32)` semantics: if the first extracted element equals
// 'v', the result is the second element; otherwise it is the first.
Value vdstNew =
- rewriter.create<LLVM::SelectOp>(loc, isEqual, vdst1, vdst0);
+ LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
permuted.emplace_back(vdstNew);
}
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 42099aa..12adfe1 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -93,11 +93,11 @@ struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> {
Location loc = op.getLoc();
Value exponentReal =
- rewriter.create<arith::SIToFPOp>(loc, exponentFloatType, op.getRhs());
- Value zeroImag = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(exponentFloatType));
- Value exponent = rewriter.create<complex::CreateOp>(
- loc, op.getLhs().getType(), exponentReal, zeroImag);
+ arith::SIToFPOp::create(rewriter, loc, exponentFloatType, op.getRhs());
+ Value zeroImag = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getZeroAttr(exponentFloatType));
+ Value exponent = complex::CreateOp::create(
+ rewriter, loc, op.getLhs().getType(), exponentReal, zeroImag);
rewriter.replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
exponent, op.getFastmathAttr());
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 5613e02..0fe7239 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -937,14 +937,14 @@ struct PowiOpConversion : public OpConversionPattern<complex::PowiOp> {
auto elementType = cast<FloatType>(type.getElementType());
Value floatExponent =
- builder.create<arith::SIToFPOp>(elementType, adaptor.getRhs());
+ arith::SIToFPOp::create(builder, elementType, adaptor.getRhs());
Value zero = arith::ConstantOp::create(
builder, elementType, builder.getFloatAttr(elementType, 0.0));
Value complexExponent =
complex::CreateOp::create(builder, type, floatExponent, zero);
- auto pow = builder.create<complex::PowOp>(
- type, adaptor.getLhs(), complexExponent, op.getFastmathAttr());
+ auto pow = complex::PowOp::create(builder, type, adaptor.getLhs(),
+ complexExponent, op.getFastmathAttr());
rewriter.replaceOp(op, pow.getResult());
return success();
}
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 2285d26..eb662a1 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -507,7 +507,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
/*isVarArg=*/true);
LLVM::LLVMFuncOp printfDecl =
- getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
+ getOrDefineFunction(moduleOp, loc, rewriter, funcName, printfType);
+ printfDecl.setCConv(callingConvention);
// Create the global op or find an existing one.
LLVM::GlobalOp global = getOrCreateStringConstant(
@@ -530,7 +531,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
printfArgs.push_back(stringStart);
printfArgs.append(argsRange.begin(), argsRange.end());
- LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
+ auto call = LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
+ call.setCConv(callingConvention);
rewriter.eraseOp(gpuPrintfOp);
return success();
}
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 66d3bb4..ec74787 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -10,6 +10,7 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
namespace mlir {
@@ -142,13 +143,23 @@ struct GPUPrintfOpToHIPLowering : public ConvertOpToLLVMPattern<gpu::PrintfOp> {
/// This pass will add a declaration of printf() to the GPUModule if needed
/// and separate out the format strings into global constants. For some
/// runtimes, such as OpenCL on AMD, this is sufficient setup, as the compiler
-/// will lower printf calls to appropriate device-side code
+/// will lower printf calls to appropriate device-side code.
+/// However not all backends use the same calling convention and function
+/// naming.
+/// For example, the LLVM SPIRV backend requires calling convention
+/// LLVM::cconv::CConv::SPIR_FUNC and function name needs to be
+/// mangled as "_Z6printfPU3AS2Kcz".
+/// Default callingConvention is LLVM::cconv::CConv::C and
+/// funcName is "printf" but they can be customized as needed.
struct GPUPrintfOpToLLVMCallLowering
: public ConvertOpToLLVMPattern<gpu::PrintfOp> {
- GPUPrintfOpToLLVMCallLowering(const LLVMTypeConverter &converter,
- int addressSpace = 0)
+ GPUPrintfOpToLLVMCallLowering(
+ const LLVMTypeConverter &converter, int addressSpace = 0,
+ LLVM::cconv::CConv callingConvention = LLVM::cconv::CConv::C,
+ StringRef funcName = "printf")
: ConvertOpToLLVMPattern<gpu::PrintfOp>(converter),
- addressSpace(addressSpace) {}
+ addressSpace(addressSpace), callingConvention(callingConvention),
+ funcName(funcName) {}
LogicalResult
matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
@@ -156,6 +167,8 @@ struct GPUPrintfOpToLLVMCallLowering
private:
int addressSpace;
+ LLVM::cconv::CConv callingConvention;
+ StringRef funcName;
};
/// Lowering of gpu.printf to a vprintf standard library.
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index c2363a1..25f1e1b 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -470,10 +470,13 @@ struct GPUToLLVMSPVConversionPass final
gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
- gpu::ThreadIdOp>();
+ gpu::ThreadIdOp, gpu::PrintfOp>();
populateGpuToLLVMSPVConversionPatterns(converter, patterns);
populateGpuMemorySpaceAttributeConversions(converter);
+ patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/2,
+ LLVM::cconv::CConv::SPIR_FUNC,
+ "_Z6printfPU3AS2Kcz");
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 852c50c..d64c4d6 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -500,19 +500,19 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
assert(scope && "Expected op to be inside automatic allocation scope");
rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
- auto one = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
+ auto one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
+ rewriter.getI32IntegerAttr(1));
sinPtr =
- rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+ LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
cosPtr =
- rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+ LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
}
createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
op);
- auto sinResult = rewriter.create<LLVM::LoadOp>(loc, computeType, sinPtr);
- auto cosResult = rewriter.create<LLVM::LoadOp>(loc, computeType, cosPtr);
+ auto sinResult = LLVM::LoadOp::create(rewriter, loc, computeType, sinPtr);
+ auto cosResult = LLVM::LoadOp::create(rewriter, loc, computeType, cosPtr);
rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
maybeTrunc(cosResult, inputType, rewriter)});
@@ -522,14 +522,15 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
private:
Value maybeExt(Value operand, PatternRewriter &rewriter) const {
if (isa<Float16Type, BFloat16Type>(operand.getType()))
- return rewriter.create<LLVM::FPExtOp>(
- operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
+ return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
+ Float32Type::get(rewriter.getContext()),
+ operand);
return operand;
}
Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
if (operand.getType() != type)
- return rewriter.create<LLVM::FPTruncOp>(operand.getLoc(), type, operand);
+ return LLVM::FPTruncOp::create(rewriter, operand.getLoc(), type, operand);
return operand;
}
@@ -556,7 +557,7 @@ private:
}
SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
- rewriter.create<LLVM::CallOp>(loc, funcOp, callOperands);
+ LLVM::CallOp::create(rewriter, loc, funcOp, callOperands);
}
};
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 229e40e..7cce324 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -142,8 +142,8 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
auto structType = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), {llvmOperandType, llvmOperandType});
- auto sincosOp = rewriter.create<LLVM::SincosOp>(
- loc, structType, adaptor.getOperand(), attrs.getAttrs());
+ auto sincosOp = LLVM::SincosOp::create(
+ rewriter, loc, structType, adaptor.getOperand(), attrs.getAttrs());
auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 519d9c8..71e3f88 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -394,9 +394,9 @@ private:
if (!convertedType)
return rewriter.notifyMatchFailure(whileOp, "type conversion failed");
- emitc::VariableOp var = rewriter.create<emitc::VariableOp>(
- loc, emitc::LValueType::get(convertedType), noInit);
- rewriter.create<emitc::AssignOp>(loc, var.getResult(), init);
+ auto var = emitc::VariableOp::create(
+ rewriter, loc, emitc::LValueType::get(convertedType), noInit);
+ emitc::AssignOp::create(rewriter, loc, var.getResult(), init);
loopVars.push_back(var);
}
@@ -411,11 +411,11 @@ private:
// Create a global boolean variable to store the loop condition state.
Type i1Type = IntegerType::get(context, 1);
auto globalCondition =
- rewriter.create<emitc::VariableOp>(loc, emitc::LValueType::get(i1Type),
- emitc::OpaqueAttr::get(context, ""));
+ emitc::VariableOp::create(rewriter, loc, emitc::LValueType::get(i1Type),
+ emitc::OpaqueAttr::get(context, ""));
Value conditionVal = globalCondition.getResult();
- auto loweredDo = rewriter.create<emitc::DoOp>(loc);
+ auto loweredDo = emitc::DoOp::create(rewriter, loc);
// Convert region types to match the target dialect type system.
if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(),
@@ -450,12 +450,12 @@ private:
// Convert scf.condition to condition variable assignment.
Value condition = rewriter.getRemappedValue(condOp.getCondition());
- rewriter.create<emitc::AssignOp>(loc, conditionVal, condition);
+ emitc::AssignOp::create(rewriter, loc, conditionVal, condition);
// Wrap body region in conditional to preserve scf semantics. Only create
// ifOp if after-region is non-empty.
if (whileOp.getAfterBody()->getOperations().size() > 1) {
- auto ifOp = rewriter.create<emitc::IfOp>(loc, condition, false, false);
+ auto ifOp = emitc::IfOp::create(rewriter, loc, condition, false, false);
// Prepare the after region (loop body) for merging.
Block *afterBlock = &whileOp.getAfter().front();
@@ -480,8 +480,8 @@ private:
Block *condBlock = rewriter.createBlock(&condRegion);
rewriter.setInsertionPointToStart(condBlock);
- auto exprOp = rewriter.create<emitc::ExpressionOp>(
- loc, i1Type, conditionVal, /*do_not_inline=*/false);
+ auto exprOp = emitc::ExpressionOp::create(
+ rewriter, loc, i1Type, conditionVal, /*do_not_inline=*/false);
Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion());
// Set up the expression block to load the condition variable.
@@ -490,12 +490,12 @@ private:
// Load the condition value and yield it as the expression result.
Value cond =
- rewriter.create<emitc::LoadOp>(loc, i1Type, exprBlock->getArgument(0));
- rewriter.create<emitc::YieldOp>(loc, cond);
+ emitc::LoadOp::create(rewriter, loc, i1Type, exprBlock->getArgument(0));
+ emitc::YieldOp::create(rewriter, loc, cond);
// Yield the expression as the condition region result.
rewriter.setInsertionPointToEnd(condBlock);
- rewriter.create<emitc::YieldOp>(loc, exprOp);
+ emitc::YieldOp::create(rewriter, loc, exprOp);
return success();
}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 00df14b1..29afdc2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -232,16 +232,16 @@ static Value createLinalgBodyCalculationForElementwiseOp(
}
intermediateType = rewriter.getIntegerType(intermediateBitWidth);
- zpAddValue = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+ zpAddValue = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
} else {
intermediateType = rewriter.getIntegerType(intermediateBitWidth);
auto arg1 =
- rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[1]);
+ arith::ExtSIOp::create(rewriter, loc, intermediateType, args[1]);
auto arg2 =
- rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[2]);
+ arith::ExtSIOp::create(rewriter, loc, intermediateType, args[2]);
zpAddValue =
- rewriter.create<arith::AddIOp>(loc, intermediateType, arg1, arg2);
+ arith::AddIOp::create(rewriter, loc, intermediateType, arg1, arg2);
}
// The negation can be applied by doing:
@@ -1402,8 +1402,8 @@ static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
auto elemType = inputType.getElementType();
auto collapsedType = RankedTensorType::get({}, elemType);
// Emit the collapse op
- return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
- reassociation);
+ return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input,
+ reassociation);
}
static llvm::SmallVector<int8_t>
@@ -1443,7 +1443,7 @@ static void setupLinalgGenericOpInputAndIndexingMap(
IntegerAttr intAttr = isShift
? rewriter.getI8IntegerAttr(values.front())
: rewriter.getI32IntegerAttr(values.front());
- constant = rewriter.create<arith::ConstantOp>(loc, intAttr);
+ constant = arith::ConstantOp::create(rewriter, loc, intAttr);
} else {
auto elementType =
isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
@@ -1511,14 +1511,14 @@ static Value getExtendZp(OpBuilder &builder, Type valueTy,
.getResult(0);
}
if (zpTy.isUnsignedInteger()) {
- return builder.create<arith::ExtUIOp>(loc, extendType, result);
+ return arith::ExtUIOp::create(builder, loc, extendType, result);
} else {
- return builder.create<arith::ExtSIOp>(loc, extendType, result);
+ return arith::ExtSIOp::create(builder, loc, extendType, result);
}
}
} else {
- return builder.create<arith::ConstantOp>(
- loc, IntegerAttr::get(extendType, *maybeZp));
+ return arith::ConstantOp::create(builder, loc,
+ IntegerAttr::get(extendType, *maybeZp));
}
return result;
}