aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r--mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp8
-rw-r--r--mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp16
-rw-r--r--mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp38
-rw-r--r--mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp5
-rw-r--r--mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp41
-rw-r--r--mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp11
-rw-r--r--mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp1
-rw-r--r--mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp1
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp4
-rw-r--r--mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp24
-rw-r--r--mlir/lib/Conversion/LLVMCommon/Pattern.cpp15
-rw-r--r--mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp2
-rw-r--r--mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp12
-rw-r--r--mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp3
-rw-r--r--mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp7
-rw-r--r--mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp78
-rw-r--r--mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp36
-rw-r--r--mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp15
-rw-r--r--mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp52
-rw-r--r--mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp9
-rw-r--r--mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp3
-rw-r--r--mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp12
-rw-r--r--mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp30
-rw-r--r--mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp1
-rw-r--r--mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp133
-rw-r--r--mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp158
-rw-r--r--mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp75
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp71
-rw-r--r--mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp4
-rw-r--r--mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp64
30 files changed, 472 insertions, 457 deletions
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b6f6167..64720bf 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -481,16 +481,16 @@ struct MemoryCounterWaitOpLowering
if (chipset.majorVersion >= 12) {
Location loc = op.getLoc();
if (std::optional<int> ds = adaptor.getDs())
- rewriter.create<ROCDL::WaitDscntOp>(loc, *ds);
+ ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
if (std::optional<int> load = adaptor.getLoad())
- rewriter.create<ROCDL::WaitLoadcntOp>(loc, *load);
+ ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);
if (std::optional<int> store = adaptor.getStore())
- rewriter.create<ROCDL::WaitStorecntOp>(loc, *store);
+ ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
if (std::optional<int> exp = adaptor.getExp())
- rewriter.create<ROCDL::WaitExpcntOp>(loc, *exp);
+ ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
rewriter.eraseOp(op);
return success();
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 59b3fe2..515fe5c 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -402,8 +402,8 @@ public:
Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
// Actual cast (may change bitwidth)
- auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
- castDestType, actualOp);
+ auto cast =
+ emitc::CastOp::create(rewriter, op.getLoc(), castDestType, actualOp);
// Cast to the expected output type
auto result = adaptValueType(cast, rewriter, opReturnType);
@@ -507,8 +507,8 @@ public:
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
- Value arithmeticResult = rewriter.template create<EmitCOp>(
- op.getLoc(), arithmeticType, lhs, rhs);
+ Value arithmeticResult =
+ EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
Value result = adaptValueType(arithmeticResult, rewriter, type);
@@ -547,8 +547,8 @@ public:
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
- Value arithmeticResult = rewriter.template create<EmitCOp>(
- op.getLoc(), arithmeticType, lhs, rhs);
+ Value arithmeticResult =
+ EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
Value result = adaptValueType(arithmeticResult, rewriter, type);
@@ -748,8 +748,8 @@ public:
}
Value fpCastOperand = adaptor.getIn();
if (actualOperandType != operandType) {
- fpCastOperand = rewriter.template create<emitc::CastOp>(
- castOp.getLoc(), actualOperandType, fpCastOperand);
+ fpCastOperand = emitc::CastOp::create(rewriter, castOp.getLoc(),
+ actualOperandType, fpCastOperand);
}
rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index d43e681..265293b 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
return builder.getF32FloatAttr(dstVal.convertToFloat());
}
+// Get in IntegerAttr from FloatAttr while preserving the bits.
+// Useful for converting float constants to integer constants while preserving
+// the bits.
+static IntegerAttr
+getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
+ ConversionPatternRewriter &rewriter) {
+ APFloat floatVal = floatAttr.getValue();
+ APInt intVal = floatVal.bitcastToAPInt();
+ return rewriter.getIntegerAttr(dstType, intVal);
+}
+
/// Returns true if the given `type` is a boolean scalar or vector type.
static bool isBoolScalarOrVector(Type type) {
assert(type && "Not a valid type");
@@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final
SmallVector<Attribute, 8> elements;
if (isa<FloatType>(srcElemType)) {
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
- FloatAttr dstAttr =
- convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
+ Attribute dstAttr = nullptr;
+ // Handle 8-bit float conversion to 8-bit integer.
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcElemType.getIntOrFloatBitWidth() == 8 &&
+ isa<IntegerType>(dstElemType)) {
+ dstAttr =
+ getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
+ } else {
+ dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
+ rewriter);
+ }
if (!dstAttr)
return failure();
elements.push_back(dstAttr);
@@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final
// Floating-point types.
if (isa<FloatType>(srcType)) {
auto srcAttr = cast<FloatAttr>(cstAttr);
- auto dstAttr = srcAttr;
+ Attribute dstAttr = srcAttr;
// Floating-point types not supported in the target environment are all
// converted to float type.
- if (srcType != dstType) {
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
+ dstType.getIntOrFloatBitWidth() == 8) {
+ // If the source is an 8-bit float, convert it to a 8-bit integer.
+ dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
+ if (!dstAttr)
+ return failure();
+ } else if (srcType != dstType) {
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
if (!dstAttr)
return failure();
@@ -1352,6 +1381,7 @@ struct ConvertArithToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index 30a7170..3edcbb8 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -68,9 +68,8 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
scf::YieldOp::create(rewriter, loc, acc);
};
- auto size = rewriter
- .create<scf::ForOp>(loc, zero, rank, one, ValueRange(one),
- loopBody)
+ auto size = scf::ForOp::create(rewriter, loc, zero, rank, one,
+ ValueRange(one), loopBody)
.getResult(0);
MemRefType memrefType = MemRefType::get({ShapedType::kDynamic},
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 6f0fc29..35ad99c 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
patterns.getContext(), "__ocml_cabs_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
patterns.getContext(), "__ocml_cabs_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>(
+ patterns.getContext(), "__ocml_carg_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>(
+ patterns.getContext(), "__ocml_carg_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>(
+ patterns.getContext(), "__ocml_conj_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>(
+ patterns.getContext(), "__ocml_conj_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ccos_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ccos_f64");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>(
patterns.getContext(), "__ocml_cexp_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>(
patterns.getContext(), "__ocml_cexp_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>(
+ patterns.getContext(), "__ocml_clog_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>(
+ patterns.getContext(), "__ocml_clog_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>(
+ patterns.getContext(), "__ocml_cpow_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>(
+ patterns.getContext(), "__ocml_cpow_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>(
+ patterns.getContext(), "__ocml_csin_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>(
+ patterns.getContext(), "__ocml_csin_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>(
+ patterns.getContext(), "__ocml_csqrt_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>(
+ patterns.getContext(), "__ocml_csqrt_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ctan_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ctan_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ctanh_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ctanh_f64");
}
namespace {
@@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
- target.addIllegalOp<complex::AbsOp, complex::ExpOp>();
+ target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp,
+ complex::CosOp, complex::ExpOp, complex::LogOp,
+ complex::PowOp, complex::SinOp, complex::SqrtOp,
+ complex::TanOp, complex::TanhOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
index c8311eb..5ac838c 100644
--- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
@@ -144,12 +144,11 @@ ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc,
return emitError(loc, "Cannot create unreachable terminator for '")
<< parentOp->getName() << "'";
- return builder
- .create<func::ReturnOp>(
- loc, llvm::map_to_vector(funcOp.getResultTypes(),
- [&](Type type) {
- return getUndefValue(loc, builder, type);
- }))
+ return func::ReturnOp::create(
+ builder, loc,
+ llvm::map_to_vector(
+ funcOp.getResultTypes(),
+ [&](Type type) { return getUndefValue(loc, builder, type); }))
.getOperation();
}
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index 03f4bf4..56b6181 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// TODO: We should also take care of block argument type conversion.
diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
index 8ed9f65..c0439a4 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
@@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 63eb6c58..3cfbd89 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -579,8 +579,8 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
auto function = [&] {
if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
return function;
- return OpBuilder::atBlockEnd(module.getBody())
- .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
+ auto builder = OpBuilder::atBlockEnd(module.getBody());
+ return LLVM::LLVMFuncOp::create(builder, loc, functionName, functionType);
}();
return LLVM::CallOp::create(builder, loc, function, arguments);
}
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index a19194e..1817861 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -507,25 +507,27 @@ LogicalResult GPURotateConversion::matchAndRewrite(
getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
unsigned subgroupSize =
targetEnv.getAttr().getResourceLimits().getSubgroupSize();
- IntegerAttr widthAttr;
- if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) ||
- widthAttr.getValue().getZExtValue() > subgroupSize)
+ unsigned width = rotateOp.getWidth();
+ if (width > subgroupSize)
return rewriter.notifyMatchFailure(
- rotateOp,
- "rotate width is not a constant or larger than target subgroup size");
+ rotateOp, "rotate width is larger than target subgroup size");
Location loc = rotateOp.getLoc();
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
+ Value offsetVal =
+ arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr());
+ Value widthVal =
+ arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr());
Value rotateResult = spirv::GroupNonUniformRotateKHROp::create(
- rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset(),
- adaptor.getWidth());
+ rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal);
Value validVal;
- if (widthAttr.getValue().getZExtValue() == subgroupSize) {
+ if (width == subgroupSize) {
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
} else {
+ IntegerAttr widthAttr = adaptor.getWidthAttr();
Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
- laneId, adaptor.getWidth());
+ laneId, widthVal);
}
rewriter.replaceOp(rotateOp, {rotateResult, validVal});
@@ -559,8 +561,8 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
builder, loc, builder.getI32Type(),
builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
- return builder
- .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
+ return NonUniformOp::create(builder, loc, type, scope, groupOp, arg,
+ clusterSizeValue)
.getResult();
}
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index ecd5b63..2568044 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -272,14 +272,13 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
// Allocate memory, copy, and free the source if necessary.
Value memory =
- toDynamic
- ? builder
- .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
- .getResult()
- : LLVM::AllocaOp::create(builder, loc, getPtrType(),
- IntegerType::get(getContext(), 8),
- allocationSize,
- /*alignment=*/0);
+ toDynamic ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
+ allocationSize)
+ .getResult()
+ : LLVM::AllocaOp::create(builder, loc, getPtrType(),
+ IntegerType::get(getContext(), 8),
+ allocationSize,
+ /*alignment=*/0);
Value source = desc.memRefDescPtr(builder, loc);
LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
if (!toDynamic)
diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
index 5b68eb8..e5496e5 100644
--- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
+++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp
@@ -35,7 +35,7 @@ static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc,
if (!(ret = moduleOp.lookupSymbol<Op>(name))) {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
- ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...);
+ ret = Op::create(rewriter, loc, std::forward<Args>(args)...);
}
return ret;
}
diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
index b09afd9..cde2340 100644
--- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
+++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -22,7 +22,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOFUNCS
@@ -32,7 +32,6 @@ namespace mlir {
using namespace mlir;
#define DEBUG_TYPE "math-to-funcs"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
namespace {
// Pattern to convert vector operations to scalar operations.
@@ -653,10 +652,8 @@ FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
/// }
static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
if (!isa<IntegerType>(elementType)) {
- LLVM_DEBUG({
- DBGS() << "non-integer element type for CtlzFunc; type was: ";
- elementType.print(llvm::dbgs());
- });
+ LDBG() << "non-integer element type for CtlzFunc; type was: "
+ << elementType;
llvm_unreachable("non-integer element type");
}
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
@@ -698,7 +695,8 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
scf::IfOp ifOp =
scf::IfOp::create(builder, elementType, inputEqZero,
/*addThenBlock=*/true, /*addElseBlock=*/true);
- ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue);
+ auto thenBuilder = ifOp.getThenBodyBuilder();
+ scf::YieldOp::create(thenBuilder, loc, bitWidthValue);
auto elseBuilder =
ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front());
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 93d8b49..df219f3 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -21,7 +22,6 @@
#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
-#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOROCDL
@@ -31,7 +31,6 @@ namespace mlir {
using namespace mlir;
#define DEBUG_TYPE "math-to-rocdl"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index a877ad2..1787e0a 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -488,7 +488,12 @@ namespace mlir {
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
// Core patterns
- patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
+ patterns
+ .add<CopySignPattern,
+ CheckedElementwiseOpPattern<math::IsInfOp, spirv::IsInfOp>,
+ CheckedElementwiseOpPattern<math::IsNaNOp, spirv::IsNanOp>,
+ CheckedElementwiseOpPattern<math::IsFiniteOp, spirv::IsFiniteOp>>(
+ typeConverter, patterns.getContext());
// GLSL patterns
patterns
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index e882845..6bd0e2d 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -19,10 +19,18 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
+#include <cstdint>
using namespace mlir;
+static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) {
+ return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
+ memRefType.getRank() != 0 &&
+ !llvm::is_contained(memRefType.getShape(), 0);
+}
+
namespace {
/// Implement the interface to convert MemRef to EmitC.
struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
@@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
return resultTy;
}
+struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = allocOp.getLoc();
+ MemRefType memrefType = allocOp.getType();
+ if (!isMemRefTypeLegalForEmitC(memrefType)) {
+ return rewriter.notifyMatchFailure(
+ loc, "incompatible memref type for EmitC conversion");
+ }
+
+ Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
+ Type elementType = memrefType.getElementType();
+ IndexType indexType = rewriter.getIndexType();
+ emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>(
+ loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{},
+ ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
+
+ int64_t numElements = 1;
+ for (int64_t dimSize : memrefType.getShape()) {
+ numElements *= dimSize;
+ }
+ Value numElementsValue = rewriter.create<emitc::ConstantOp>(
+ loc, indexType, rewriter.getIndexAttr(numElements));
+
+ Value totalSizeBytes = rewriter.create<emitc::MulOp>(
+ loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue);
+
+ emitc::CallOpaqueOp allocCall;
+ StringAttr allocFunctionName;
+ Value alignmentValue;
+ SmallVector<Value, 2> argsVec;
+ if (allocOp.getAlignment()) {
+ allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName);
+ alignmentValue = rewriter.create<emitc::ConstantOp>(
+ loc, sizeTType,
+ rewriter.getIntegerAttr(indexType,
+ allocOp.getAlignment().value_or(0)));
+ argsVec.push_back(alignmentValue);
+ } else {
+ allocFunctionName = rewriter.getStringAttr(mallocFunctionName);
+ }
+
+ argsVec.push_back(totalSizeBytes);
+ ValueRange args(argsVec);
+
+ allocCall = rewriter.create<emitc::CallOpaqueOp>(
+ loc,
+ emitc::PointerType::get(
+ emitc::OpaqueType::get(rewriter.getContext(), "void")),
+ allocFunctionName, args);
+
+ emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
+ emitc::CastOp castOp = rewriter.create<emitc::CastOp>(
+ loc, targetPointerType, allocCall.getResult(0));
+
+ rewriter.replaceOp(allocOp, castOp);
+ return success();
+ }
+};
+
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;
@@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
typeConverter.addConversion(
[&](MemRefType memRefType) -> std::optional<Type> {
- if (!memRefType.hasStaticShape() ||
- !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 ||
- llvm::is_contained(memRefType.getShape(), 0)) {
+ if (!isMemRefTypeLegalForEmitC(memRefType)) {
return {};
}
Type convertedElementType =
@@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
- patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
- ConvertStore>(converter, patterns.getContext());
+ patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
+ ConvertLoad, ConvertStore>(converter, patterns.getContext());
}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index cf25c09..e78dd76 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -15,6 +15,7 @@
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -28,9 +29,11 @@ using namespace mlir;
namespace {
struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
+ using Base::Base;
void runOnOperation() override {
TypeConverter converter;
-
+ ConvertMemRefToEmitCOptions options;
+ options.lowerToCpp = this->lowerToCpp;
// Fallback for other types.
converter.addConversion([](Type type) -> std::optional<Type> {
if (!emitc::isSupportedEmitCType(type))
@@ -50,6 +53,37 @@ struct ConvertMemRefToEmitCPass
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
+
+ mlir::ModuleOp module = getOperation();
+ module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
+ if (callOp.getCallee() != alignedAllocFunctionName &&
+ callOp.getCallee() != mallocFunctionName) {
+ return mlir::WalkResult::advance();
+ }
+
+ for (auto &op : *module.getBody()) {
+ emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
+ if (!includeOp) {
+ continue;
+ }
+ if (includeOp.getIsStandardInclude() &&
+ ((options.lowerToCpp &&
+ includeOp.getInclude() == cppStandardLibraryHeader) ||
+ (!options.lowerToCpp &&
+ includeOp.getInclude() == cStandardLibraryHeader))) {
+ return mlir::WalkResult::interrupt();
+ }
+ }
+
+ mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
+ StringAttr includeAttr =
+ builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
+ : cStandardLibraryHeader);
+ builder.create<mlir::emitc::IncludeOp>(
+ module.getLoc(), includeAttr,
+ /*is_standard_include=*/builder.getUnitAttr());
+ return mlir::WalkResult::interrupt();
+ });
}
};
} // namespace
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 53a1912..dc2035b 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -24,11 +24,12 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"
+
#include <optional>
#define DEBUG_TYPE "memref-to-llvm"
-#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] "
namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
@@ -575,8 +576,8 @@ private:
Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy,
getTypeConverter()->getIndexType(),
offsetPtr, idxPlusOne);
- return rewriter
- .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
+ return LLVM::LoadOp::create(rewriter, loc,
+ getTypeConverter()->getIndexType(), sizePtr)
.getResult();
}
@@ -1848,8 +1849,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::xchg;
case arith::AtomicRMWKind::maximumf:
// TODO: remove this by end of 2025.
- LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed "
- "from fmax to fmaximum, expect more NaNs");
+ LDBG() << "the lowering of memref.atomicrmw maximumf changed "
+ "from fmax to fmaximum, expect more NaNs";
return LLVM::AtomicBinOp::fmaximum;
case arith::AtomicRMWKind::maxnumf:
return LLVM::AtomicBinOp::fmax;
@@ -1859,8 +1860,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::umax;
case arith::AtomicRMWKind::minimumf:
// TODO: remove this by end of 2025.
- LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimum changed "
- "from fmin to fminimum, expect more NaNs");
+ LDBG() << "the lowering of memref.atomicrmw minimum changed "
+ "from fmin to fminimum, expect more NaNs";
return LLVM::AtomicBinOp::fminimum;
case arith::AtomicRMWKind::minnumf:
return LLVM::AtomicBinOp::fmin;
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 5d13353..2549a9c 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -26,13 +26,12 @@
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#define DEBUG_TYPE "nvgpu-to-nvvm"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define DBGSE() (llvm::dbgs())
namespace mlir {
#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
@@ -1105,13 +1104,13 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
// // [0,14) start_address
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
- LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
- << "leading_off:" << leadDimVal << "\t"
- << "stride_off :" << strideDimVal << "\t"
- << "base_offset:" << offsetVal << "\t"
- << "layout_type:" << swizzle << " ("
- << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
- << ")\n start_addr : " << baseAddr << "\n");
+ LDBG() << "Generating warpgroup.descriptor: "
+ << "leading_off:" << leadDimVal << "\t"
+ << "stride_off :" << strideDimVal << "\t"
+ << "base_offset:" << offsetVal << "\t"
+ << "layout_type:" << swizzle << " ("
+ << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
+ << ")\n start_addr : " << baseAddr;
rewriter.replaceOp(op, dsc);
return success();
@@ -1281,8 +1280,8 @@ struct NVGPUWarpgroupMmaOpLowering
} else {
llvm_unreachable("msg: not supported K shape");
}
- LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
- << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
+ LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
+ << ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
}
/// Generates WGMMATypesAttr from MLIR Type
@@ -1366,9 +1365,9 @@ struct NVGPUWarpgroupMmaOpLowering
int tileShapeA = matrixTypeA.getDimSize(1);
int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
incrementVal = incrementVal >> exclude4LSB;
- LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
- << "] [wgmma descriptors] Descriptor A + "
- << incrementVal << " | \t ");
+ LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
+ << "] [wgmma descriptors] Descriptor A + " << incrementVal
+ << " | \t ";
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
@@ -1391,7 +1390,7 @@ struct NVGPUWarpgroupMmaOpLowering
int byte = elemB.getIntOrFloatBitWidth() / 8;
int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
incrementVal = incrementVal >> exclude4LSB;
- LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
+ LDBG() << "Descriptor B + " << incrementVal;
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
@@ -1400,15 +1399,14 @@ struct NVGPUWarpgroupMmaOpLowering
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC) {
- LLVM_DEBUG(DBGS() << "\t wgmma."
- << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
- << "(A[" << (iterationM * wgmmaM) << ":"
- << (iterationM * wgmmaM) + wgmmaM << "]["
- << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "] * "
- << " B[" << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
- << wgmmaN << "])\n");
+ LDBG() << "\t wgmma."
+ << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A["
+ << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM
+ << "][" << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "] * "
+ << " B[" << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN
+ << "])";
Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
@@ -1467,9 +1465,9 @@ struct NVGPUWarpgroupMmaOpLowering
totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
- LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
- << "] += A[" << totalM << "][" << totalK << "] * B["
- << totalK << "][" << totalN << "] ---===\n");
+ LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
+ << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN
+ << "] ---===";
// Find the shape for one wgmma instruction
findWgmmaShape(
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 662ee9e..91788f9 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -25,11 +25,10 @@
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "nvvm-to-llvm"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
namespace mlir {
#define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS
@@ -52,17 +51,17 @@ struct PtxLowering
LogicalResult matchAndRewrite(BasicPtxBuilderInterface op,
PatternRewriter &rewriter) const override {
if (op.hasIntrinsic()) {
- LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op << "\n");
+ LDBG() << "Ptx Builder does not lower \n\t" << op;
return failure();
}
SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
- LLVM_DEBUG(DBGS() << op.getPtx() << "\n");
+ LDBG() << op.getPtx();
PtxBuilder generator(op, rewriter);
op.getAsmValues(rewriter, asmValues);
for (auto &[asmValue, modifier] : asmValues) {
- LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << &modifier);
+ LDBG() << asmValue << "\t Modifier : " << &modifier;
generator.insertValue(asmValue, modifier);
}
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 240491a..807be7e 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -582,6 +582,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
// block. This should be reconsidered if we allow break/continue in SCF.
rewriter.setInsertionPointToEnd(before);
auto condOp = cast<ConditionOp>(before->getTerminator());
+ SmallVector<Value> args = llvm::to_vector(condOp.getArgs());
rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
after, condOp.getArgs(),
continuation, ValueRange());
@@ -593,7 +594,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
// Replace the op with values "yielded" from the "before" region, which are
// visible by dominance.
- rewriter.replaceOp(whileOp, condOp.getArgs());
+ rewriter.replaceOp(whileOp, args);
return success();
}
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index aae3271..9b61540 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1493,11 +1493,11 @@ public:
Value extended;
if (op2TypeWidth < dstTypeWidth) {
if (isUnsignedIntegerOrVector(op2Type)) {
- extended = rewriter.template create<LLVM::ZExtOp>(
- loc, dstType, adaptor.getOperand2());
+ extended =
+ LLVM::ZExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
} else {
- extended = rewriter.template create<LLVM::SExtOp>(
- loc, dstType, adaptor.getOperand2());
+ extended =
+ LLVM::SExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
}
} else if (op2TypeWidth == dstTypeWidth) {
extended = adaptor.getOperand2();
@@ -1505,8 +1505,8 @@ public:
return failure();
}
- Value result = rewriter.template create<LLVMOp>(
- loc, dstType, adaptor.getOperand1(), extended);
+ Value result =
+ LLVMOp::create(rewriter, loc, dstType, adaptor.getOperand1(), extended);
rewriter.replaceOp(op, result);
return success();
}
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index 8525543..fa9e544 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -36,7 +36,6 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "shard-to-mpi"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
namespace mlir {
#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS
@@ -177,9 +176,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
auto type = RankedTensorType::get({nSplits, 2}, i64);
Value resHaloSizes =
haloSizes.empty()
- ? rewriter
- .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
- i64)
+ ? tensor::EmptyOp::create(rewriter, loc,
+ std::array<int64_t, 2>{0, 0}, i64)
.getResult()
: tensor::FromElementsOp::create(rewriter, loc, type, haloSizes)
.getResult();
@@ -306,13 +304,11 @@ public:
auto ctx = op.getContext();
Value commWorld =
mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx));
- auto rank =
- rewriter
- .create<mpi::CommRankOp>(
- loc,
- TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
- commWorld)
- .getRank();
+ auto rank = mpi::CommRankOp::create(
+ rewriter, loc,
+ TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
+ commWorld)
+ .getRank();
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
rank);
return success();
@@ -703,10 +699,9 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
// subviews need Index values
for (auto &sz : haloSizes) {
if (auto value = dyn_cast<Value>(sz))
- sz =
- rewriter
- .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value)
- .getResult();
+ sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(),
+ value)
+ .getResult();
}
// most of the offset/size/stride data is the same for all dims
@@ -758,9 +753,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
// Get the linearized ids of the neighbors (down and up) for the
// given split
- auto tmp = rewriter
- .create<NeighborsLinearIndicesOp>(loc, grid, myMultiIndex,
- splitAxes)
+ auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid,
+ myMultiIndex, splitAxes)
.getResults();
// MPI operates on i32...
Value neighbourIDs[2] = {
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
index f07386e..8cd650e 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
@@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 5c7c027..0e3de06 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -569,10 +569,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// to UIToFP.
if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
auto unrealizedCast =
- rewriter
- .create<UnrealizedConversionCastOp>(
- loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
- args[0])
+ UnrealizedConversionCastOp::create(
+ rewriter, loc,
+ rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0])
.getResult(0);
return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
unrealizedCast);
@@ -868,14 +867,13 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
// Emit 'linalg.generic' op
auto resultTensor =
- opBuilder
- .create<linalg::GenericOp>(
- loc, outputTensor.getType(), operand, outputTensor, affineMaps,
- getNParallelLoopsAttrs(rank),
- [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
- // Emit 'linalg.yield' op
- linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
- })
+ linalg::GenericOp::create(
+ opBuilder, loc, outputTensor.getType(), operand, outputTensor,
+ affineMaps, getNParallelLoopsAttrs(rank),
+ [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
+ // Emit 'linalg.yield' op
+ linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
+ })
.getResult(0);
// Cast to original operand type if necessary
@@ -1155,11 +1153,9 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
inputs.push_back(input);
// First fill the output buffer with the init value.
- auto emptyTensor =
- rewriter
- .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
- dynDims)
- .getResult();
+ auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, reduceShape,
+ resultTy.getElementType(), dynDims)
+ .getResult();
auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
if (!fillValueAttr)
@@ -1167,10 +1163,10 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
op, "No initial value found for reduction operation");
auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
- auto filledTensor = rewriter
- .create<linalg::FillOp>(loc, ValueRange{fillValue},
- ValueRange{emptyTensor})
- .result();
+ auto filledTensor =
+ linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
+ ValueRange{emptyTensor})
+ .result();
outputs.push_back(filledTensor);
bool isNanIgnoreMode = false;
@@ -1186,14 +1182,12 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
auto trueAttr = rewriter.getBoolAttr(true);
auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
auto emptyBoolTensor =
- rewriter
- .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
- dynDims)
+ tensor::EmptyOp::create(rewriter, loc, reduceShape,
+ trueValue.getType(), dynDims)
.getResult();
auto allResultsNaNTensor =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{trueValue},
- ValueRange{emptyBoolTensor})
+ linalg::FillOp::create(rewriter, loc, ValueRange{trueValue},
+ ValueRange{emptyBoolTensor})
.result();
// Note that because the linalg::ReduceOp has two variadic arguments
// (inputs and outputs) and it has the SameVariadicOperandSize trait we
@@ -1261,22 +1255,19 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
auto emptyNanTensor =
- rewriter
- .create<tensor::EmptyOp>(loc, reduceShape,
- resultTy.getElementType(), dynDims)
+ tensor::EmptyOp::create(rewriter, loc, reduceShape,
+ resultTy.getElementType(), dynDims)
.getResult();
auto nanFilledTensor =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{nanValue},
- ValueRange{emptyNanTensor})
+ linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
+ ValueRange{emptyNanTensor})
.result();
// Create an empty tensor, non need to fill this since it will be
// overwritten by the select.
auto finalEmptyTensor =
- rewriter
- .create<tensor::EmptyOp>(loc, reduceShape,
- resultTy.getElementType(), dynDims)
+ tensor::EmptyOp::create(rewriter, loc, reduceShape,
+ resultTy.getElementType(), dynDims)
.getResult();
// Do a selection between the tensors akin to:
@@ -1503,12 +1494,11 @@ public:
Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
if (valueTy.isUnsignedInteger()) {
- value = nestedBuilder
- .create<UnrealizedConversionCastOp>(
- nestedLoc,
- nestedBuilder.getIntegerType(
- valueTy.getIntOrFloatBitWidth()),
- value)
+ value = UnrealizedConversionCastOp::create(
+ nestedBuilder, nestedLoc,
+ nestedBuilder.getIntegerType(
+ valueTy.getIntOrFloatBitWidth()),
+ value)
.getResult(0);
}
if (valueTy.getIntOrFloatBitWidth() < 32) {
@@ -1557,9 +1547,8 @@ public:
}
if (outIntType.isUnsignedInteger()) {
- value = nestedBuilder
- .create<UnrealizedConversionCastOp>(nestedLoc,
- outIntType, value)
+ value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc,
+ outIntType, value)
.getResult(0);
}
linalg::YieldOp::create(nestedBuilder, loc, value);
@@ -2095,10 +2084,9 @@ public:
Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);
// First fill the output buffer with the init value.
- auto emptyTensor = rewriter
- .create<tensor::EmptyOp>(loc, inputTy.getShape(),
- inputTy.getElementType(),
- ArrayRef<Value>({dynDims}))
+ auto emptyTensor = tensor::EmptyOp::create(
+ rewriter, loc, inputTy.getShape(),
+ inputTy.getElementType(), ArrayRef<Value>({dynDims}))
.getResult();
SmallVector<AffineMap, 2> affineMaps = {
rewriter.getMultiDimIdentityMap(resultTy.getRank())};
@@ -2241,23 +2229,22 @@ public:
}
// First fill the output buffer for the index.
- auto emptyTensorIdx = rewriter
- .create<tensor::EmptyOp>(loc, resultTy.getShape(),
- outElementTy, dynDims)
- .getResult();
+ auto emptyTensorIdx =
+ tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
+ outElementTy, dynDims)
+ .getResult();
auto fillValueIdx = arith::ConstantOp::create(
rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
auto filledTensorIdx =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
- ValueRange{emptyTensorIdx})
+ linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx},
+ ValueRange{emptyTensorIdx})
.result();
// Second fill the output buffer for the running max.
- auto emptyTensorMax = rewriter
- .create<tensor::EmptyOp>(loc, resultTy.getShape(),
- inElementTy, dynDims)
- .getResult();
+ auto emptyTensorMax =
+ tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy,
+ dynDims)
+ .getResult();
auto fillValueMaxAttr =
createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
@@ -2268,9 +2255,8 @@ public:
auto fillValueMax =
arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
auto filledTensorMax =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
- ValueRange{emptyTensorMax})
+ linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax},
+ ValueRange{emptyTensorMax})
.result();
// We need to reduce along the arg-max axis, with parallel operations along
@@ -2371,9 +2357,8 @@ public:
auto loc = op.getLoc();
auto emptyTensor =
- rewriter
- .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
- dynamicDims)
+ tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
+ resultElementTy, dynamicDims)
.getResult();
SmallVector<AffineMap, 2> affineMaps = {
@@ -2448,10 +2433,10 @@ public:
}
}
- auto emptyTensor = rewriter
- .create<tensor::EmptyOp>(loc, resultTy.getShape(),
- resultElementTy, dynDims)
- .getResult();
+ auto emptyTensor =
+ tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
+ resultElementTy, dynDims)
+ .getResult();
SmallVector<AffineMap, 2> affineMaps = {
rewriter.getMultiDimIdentityMap(resultTy.getRank()),
@@ -2585,10 +2570,10 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
- auto filledTensor = rewriter
- .create<linalg::FillOp>(loc, ValueRange{fillValue},
- ValueRange{emptyTensor})
- .result();
+ auto filledTensor =
+ linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
+ ValueRange{emptyTensor})
+ .result();
return filledTensor;
}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 3a20524..da1fb20 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -64,19 +64,20 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
Value conv, Value result,
ArrayRef<AffineMap> indexingMaps) {
ShapedType resultTy = cast<ShapedType>(conv.getType());
- return rewriter
- .create<linalg::GenericOp>(
- loc, resultTy, ValueRange({bias, conv}), result, indexingMaps,
- getNParallelLoopsAttrs(resultTy.getRank()),
- [](OpBuilder &builder, Location loc, ValueRange args) {
- Value biasVal = args[0];
- Type resType = args[1].getType();
- if (resType != biasVal.getType()) {
- biasVal = arith::ExtSIOp::create(builder, loc, resType, biasVal);
- }
- Value added = arith::AddIOp::create(builder, loc, biasVal, args[1]);
- linalg::YieldOp::create(builder, loc, added);
- })
+ return linalg::GenericOp::create(
+ rewriter, loc, resultTy, ValueRange({bias, conv}), result,
+ indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
+ [](OpBuilder &builder, Location loc, ValueRange args) {
+ Value biasVal = args[0];
+ Type resType = args[1].getType();
+ if (resType != biasVal.getType()) {
+ biasVal =
+ arith::ExtSIOp::create(builder, loc, resType, biasVal);
+ }
+ Value added =
+ arith::AddIOp::create(builder, loc, biasVal, args[1]);
+ linalg::YieldOp::create(builder, loc, added);
+ })
.getResult(0);
}
@@ -124,23 +125,23 @@ static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
// Build the broadcast-like operation as a linalg.generic.
- return rewriter
- .create<linalg::GenericOp>(
- loc, resultTy, ValueRange({source}), result, indexingMaps,
- getNParallelLoopsAttrs(resultTy.getRank()),
- [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
- Value biasVal = args[0];
- Type resType = args[1].getType();
- if (resType != biasVal.getType()) {
- biasVal =
- resultTy.getElementType().isFloat()
- ? arith::ExtFOp::create(builder, loc, resType, biasVal)
- .getResult()
- : arith::ExtSIOp::create(builder, loc, resType, biasVal)
- .getResult();
- }
- linalg::YieldOp::create(builder, loc, biasVal);
- })
+ return linalg::GenericOp::create(
+ rewriter, loc, resultTy, ValueRange({source}), result,
+ indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
+ [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
+ Value biasVal = args[0];
+ Type resType = args[1].getType();
+ if (resType != biasVal.getType()) {
+ biasVal =
+ resultTy.getElementType().isFloat()
+ ? arith::ExtFOp::create(builder, loc, resType, biasVal)
+ .getResult()
+ : arith::ExtSIOp::create(builder, loc, resType,
+ biasVal)
+ .getResult();
+ }
+ linalg::YieldOp::create(builder, loc, biasVal);
+ })
.getResult(0);
}
@@ -397,21 +398,19 @@ public:
auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp);
- Value conv =
- rewriter
- .create<LinalgConvQOp>(
- loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
- ValueRange{broadcastBias}, strideAttr, dilationAttr)
- ->getResult(0);
+ Value conv = LinalgConvQOp::create(
+ rewriter, loc, resultTy,
+ ValueRange{input, weight, iZpVal, kZpVal},
+ ValueRange{broadcastBias}, strideAttr, dilationAttr)
+ ->getResult(0);
rewriter.replaceOp(op, conv);
return success();
}
- Value conv = rewriter
- .create<LinalgConvOp>(
- loc, accTy, ValueRange{input, weight},
- ValueRange{broadcastBias}, strideAttr, dilationAttr)
+ Value conv = LinalgConvOp::create(
+ rewriter, loc, accTy, ValueRange{input, weight},
+ ValueRange{broadcastBias}, strideAttr, dilationAttr)
->getResult(0);
// We may need to truncate back to the result type if the accumulator was
@@ -529,9 +528,8 @@ public:
Value emptyTensor = tensor::EmptyOp::create(
rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims);
Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr);
- Value zeroTensor = rewriter
- .create<linalg::FillOp>(loc, ValueRange{zero},
- ValueRange{emptyTensor})
+ Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
+ ValueRange{emptyTensor})
.result();
Value biasEmptyTensor = tensor::EmptyOp::create(
@@ -544,10 +542,9 @@ public:
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
if (hasNullZps) {
- Value conv = rewriter
- .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
- loc, linalgConvTy, ValueRange{input, weight},
- ValueRange{zeroTensor}, strideAttr, dilationAttr)
+ Value conv = linalg::DepthwiseConv2DNhwcHwcmOp::create(
+ rewriter, loc, linalgConvTy, ValueRange{input, weight},
+ ValueRange{zeroTensor}, strideAttr, dilationAttr)
.getResult(0);
// We may need to truncate back to the result type if the accumulator was
@@ -565,22 +562,20 @@ public:
rewriter, loc, resultTy, conv, reassociationMap);
Value result =
- rewriter
- .create<linalg::GenericOp>(
- loc, resultTy, ValueRange({bias, convReshape}),
- biasEmptyTensor, indexingMaps,
- getNParallelLoopsAttrs(resultRank),
- [&](OpBuilder &nestedBuilder, Location nestedLoc,
- ValueRange args) {
- Value added;
- if (llvm::isa<FloatType>(inputETy))
- added = arith::AddFOp::create(nestedBuilder, loc, args[0],
- args[1]);
- else
- added = arith::AddIOp::create(nestedBuilder, loc, args[0],
- args[1]);
- linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
- })
+ linalg::GenericOp::create(
+ rewriter, loc, resultTy, ValueRange({bias, convReshape}),
+ biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultRank),
+ [&](OpBuilder &nestedBuilder, Location nestedLoc,
+ ValueRange args) {
+ Value added;
+ if (llvm::isa<FloatType>(inputETy))
+ added = arith::AddFOp::create(nestedBuilder, loc, args[0],
+ args[1]);
+ else
+ added = arith::AddIOp::create(nestedBuilder, loc, args[0],
+ args[1]);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
+ })
.getResult(0);
rewriter.replaceOp(op, result);
} else {
@@ -588,12 +583,11 @@ public:
IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp);
- Value conv =
- rewriter
- .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
- loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
- ValueRange{zeroTensor}, strideAttr, dilationAttr)
- .getResult(0);
+ Value conv = linalg::DepthwiseConv2DNhwcHwcmQOp::create(
+ rewriter, loc, linalgConvTy,
+ ValueRange{input, weight, iZpVal, kZpVal},
+ ValueRange{zeroTensor}, strideAttr, dilationAttr)
+ .getResult(0);
SmallVector<ReassociationExprs, 4> reassociationMap;
createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
Value convReshape = tensor::CollapseShapeOp::create(
@@ -639,9 +633,8 @@ public:
auto emptyTensor =
tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(),
outputTy.getElementType(), filteredDims);
- Value zeroTensor = rewriter
- .create<linalg::FillOp>(loc, ValueRange{zero},
- ValueRange{emptyTensor})
+ Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
+ ValueRange{emptyTensor})
.result();
FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
@@ -910,20 +903,18 @@ public:
rewriter, loc, accTy.getShape(), accETy, dynamicDims);
Value filledEmptyTensor =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{initialValue},
- ValueRange{poolEmptyTensor})
+ linalg::FillOp::create(rewriter, loc, ValueRange{initialValue},
+ ValueRange{poolEmptyTensor})
.result();
Value fakeWindowDims =
tensor::EmptyOp::create(rewriter, loc, kernel, accETy);
// Sum across the pooled region.
- Value poolingOp = rewriter
- .create<linalg::PoolingNhwcSumOp>(
- loc, ArrayRef<Type>{accTy},
- ValueRange{paddedInput, fakeWindowDims},
- filledEmptyTensor, strideAttr, dilationAttr)
+ Value poolingOp = linalg::PoolingNhwcSumOp::create(
+ rewriter, loc, ArrayRef<Type>{accTy},
+ ValueRange{paddedInput, fakeWindowDims},
+ filledEmptyTensor, strideAttr, dilationAttr)
.getResult(0);
// Normalize the summed value by the number of elements grouped in each
@@ -1050,10 +1041,9 @@ public:
Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
auto scaled =
- rewriter
- .create<tosa::ApplyScaleOp>(
- loc, rewriter.getI32Type(), poolVal, multiplier, shift,
- rewriter.getStringAttr("SINGLE_ROUND"))
+ tosa::ApplyScaleOp::create(
+ rewriter, loc, rewriter.getI32Type(), poolVal, multiplier,
+ shift, rewriter.getStringAttr("SINGLE_ROUND"))
.getResult();
// If we have quantization information we need to apply output
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 77aab85..1d1904f 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -31,10 +31,9 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "vector-to-gpu"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOGPU
@@ -366,7 +365,7 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
// by all operations.
if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
if (!supportsMMaMatrixType(op, useNvGpu)) {
- LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
+ LDBG() << "cannot convert op: " << *op;
return true;
}
return false;
@@ -482,14 +481,12 @@ struct CombineTransferReadOpTranspose final
permutationMap.compose(transferReadOp.getPermutationMap());
auto loc = op.getLoc();
- Value result =
- rewriter
- .create<vector::TransferReadOp>(
- loc, resultType, transferReadOp.getBase(),
- transferReadOp.getIndices(), AffineMapAttr::get(newMap),
- transferReadOp.getPadding(), transferReadOp.getMask(),
- transferReadOp.getInBoundsAttr())
- .getResult();
+ Value result = vector::TransferReadOp::create(
+ rewriter, loc, resultType, transferReadOp.getBase(),
+ transferReadOp.getIndices(), AffineMapAttr::get(newMap),
+ transferReadOp.getPadding(), transferReadOp.getMask(),
+ transferReadOp.getInBoundsAttr())
+ .getResult();
// Fuse through the integer extend op.
if (extOp) {
@@ -550,7 +547,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
std::optional<int64_t> stride =
getStaticallyKnownRowStride(op.getShapedType());
if (!stride.has_value()) {
- LLVM_DEBUG(DBGS() << "no stride\n");
+ LDBG() << "no stride";
return rewriter.notifyMatchFailure(op, "no stride");
}
@@ -585,7 +582,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
isTranspose ? rewriter.getUnitAttr() : UnitAttr());
valueMapping[mappingResult] = load;
- LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
+ LDBG() << "transfer read to: " << load;
return success();
}
@@ -599,13 +596,13 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
std::optional<int64_t> stride =
getStaticallyKnownRowStride(op.getShapedType());
if (!stride.has_value()) {
- LLVM_DEBUG(DBGS() << "no stride\n");
+ LDBG() << "no stride";
return rewriter.notifyMatchFailure(op, "no stride");
}
auto it = valueMapping.find(op.getVector());
if (it == valueMapping.end()) {
- LLVM_DEBUG(DBGS() << "no mapping\n");
+ LDBG() << "no mapping";
return rewriter.notifyMatchFailure(op, "no mapping");
}
@@ -615,9 +612,9 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
(void)store;
- LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
+ LDBG() << "transfer write to: " << store;
- LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ LDBG() << "erase: " << op;
rewriter.eraseOp(op);
return success();
}
@@ -643,21 +640,21 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
if (failed(warpMatrixInfo)) {
- LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+ LDBG() << "no warpMatrixInfo";
return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
}
FailureOr<nvgpu::FragmentElementInfo> regInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(regInfo)) {
- LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+ LDBG() << "not mma sync reg info";
return rewriter.notifyMatchFailure(op, "not mma sync reg info");
}
VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
if (!dense) {
- LLVM_DEBUG(DBGS() << "not a splat\n");
+ LDBG() << "not a splat";
return rewriter.notifyMatchFailure(op, "not a splat");
}
@@ -679,8 +676,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
mlir::AffineMap map = op.getPermutationMap();
if (map.getNumResults() != 2) {
- LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
- "is not a 2d operand\n");
+ LDBG() << "Failed because the result of `vector.transfer_read` "
+ "is not a 2d operand";
return failure();
}
@@ -693,8 +690,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
auto exprN = dyn_cast<AffineDimExpr>(dN);
if (!exprM || !exprN) {
- LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
- "expressions, then transpose cannot be determined.\n");
+ LDBG() << "Failed because expressions are not affine dim "
+ "expressions, then transpose cannot be determined.";
return failure();
}
@@ -711,20 +708,20 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
if (failed(warpMatrixInfo)) {
- LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+ LDBG() << "no warpMatrixInfo";
return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
}
FailureOr<nvgpu::FragmentElementInfo> regInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(regInfo)) {
- LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+ LDBG() << "not mma sync reg info";
return rewriter.notifyMatchFailure(op, "not mma sync reg info");
}
FailureOr<bool> transpose = isTransposed(op);
if (failed(transpose)) {
- LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
+ LDBG() << "failed to determine the transpose";
return rewriter.notifyMatchFailure(
op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
}
@@ -733,10 +730,8 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
if (failed(params)) {
- LLVM_DEBUG(
- DBGS()
- << "failed to convert vector.transfer_read to ldmatrix. "
- << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
+ LDBG() << "failed to convert vector.transfer_read to ldmatrix. "
+ << "Op should likely not be converted to a nvgpu.ldmatrix call.";
return rewriter.notifyMatchFailure(
op, "failed to convert vector.transfer_read to ldmatrix; this op "
"likely should not be converted to a nvgpu.ldmatrix call.");
@@ -747,7 +742,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
FailureOr<AffineMap> offsets =
nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
if (failed(offsets)) {
- LLVM_DEBUG(DBGS() << "no offsets\n");
+ LDBG() << "no offsets";
return rewriter.notifyMatchFailure(op, "no offsets");
}
@@ -936,7 +931,7 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices);
}
- LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ LDBG() << "erase: " << op;
rewriter.eraseOp(op);
return success();
}
@@ -1134,9 +1129,9 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
loop.getNumResults())))
rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
- LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
- LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
- LLVM_DEBUG(DBGS() << "erase: " << loop);
+ LDBG() << "newLoop now: " << newLoop;
+ LDBG() << "stripped scf.for: " << loop;
+ LDBG() << "erase: " << loop;
rewriter.eraseOp(loop);
return newLoop;
@@ -1152,7 +1147,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
auto it = valueMapping.find(operand.value());
if (it == valueMapping.end()) {
- LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
+ LDBG() << "no value mapping for: " << operand.value();
continue;
}
argMapping.push_back(std::make_pair(
@@ -1170,7 +1165,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
}
- LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
+ LDBG() << "scf.for to: " << newForOp;
return success();
}
@@ -1193,7 +1188,7 @@ convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
}
scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands);
- LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ LDBG() << "erase: " << op;
rewriter.eraseOp(op);
return success();
}
@@ -1246,7 +1241,7 @@ LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
auto globalRes = LogicalResult::success();
for (Operation *op : ops) {
- LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
+ LDBG() << "Process op: " << *op;
// Apparently callers do not want to early exit on failure here.
auto res = LogicalResult::success();
if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9cd491c..17a79e3 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -29,7 +29,9 @@
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APFloat.h"
+#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/Casting.h"
+
#include <optional>
using namespace mlir;
@@ -1068,39 +1070,6 @@ public:
}
};
-class VectorExtractElementOpConversion
- : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
-public:
- using ConvertOpToLLVMPattern<
- vector::ExtractElementOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto vectorType = extractEltOp.getSourceVectorType();
- auto llvmType = typeConverter->convertType(vectorType.getElementType());
-
- // Bail if result type cannot be lowered.
- if (!llvmType)
- return failure();
-
- if (vectorType.getRank() == 0) {
- Location loc = extractEltOp.getLoc();
- auto idxType = rewriter.getIndexType();
- auto zero = LLVM::ConstantOp::create(rewriter, loc,
- typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
- rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- extractEltOp, llvmType, adaptor.getVector(), zero);
- return success();
- }
-
- rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
- return success();
- }
-};
-
class VectorExtractOpConversion
: public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
@@ -1204,39 +1173,6 @@ public:
}
};
-class VectorInsertElementOpConversion
- : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
-public:
- using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto vectorType = insertEltOp.getDestVectorType();
- auto llvmType = typeConverter->convertType(vectorType);
-
- // Bail if result type cannot be lowered.
- if (!llvmType)
- return failure();
-
- if (vectorType.getRank() == 0) {
- Location loc = insertEltOp.getLoc();
- auto idxType = rewriter.getIndexType();
- auto zero = LLVM::ConstantOp::create(rewriter, loc,
- typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
- rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
- return success();
- }
-
- rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
- adaptor.getPosition());
- return success();
- }
-};
-
class VectorInsertOpConversion
: public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
@@ -2242,8 +2178,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorGatherOpConversion, VectorScatterOpConversion>(
converter, useVectorAlignment);
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
- VectorExtractElementOpConversion, VectorExtractOpConversion,
- VectorFMAOp1DConversion, VectorInsertElementOpConversion,
+ VectorExtractOpConversion, VectorFMAOp1DConversion,
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index b1af5f0..508f4e2 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -690,7 +690,7 @@ struct PrepareTransferWriteConversion
/// %lastIndex = arith.subi %length, %c1 : index
/// vector.print punctuation <open>
/// scf.for %i = %c0 to %length step %c1 {
-/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
+/// %el = vector.extract %v[%i] : i32 from vector<[4]xi32>
/// vector.print %el : i32 punctuation <no_punctuation>
/// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
/// scf.if %notLastIndex {
@@ -1643,7 +1643,7 @@ struct Strategy1d<TransferWriteOp> {
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
-/// %t = vector.extractelement %vec[i] : vector<9xf32>
+/// %t = vector.extract %vec[i] : f32 from vector<9xf32>
/// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 986eae3..a4be7d4 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -335,63 +335,6 @@ struct VectorInsertOpConvert final
}
};
-struct VectorExtractElementOpConvert final
- : public OpConversionPattern<vector::ExtractElementOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Type resultType = getTypeConverter()->convertType(extractOp.getType());
- if (!resultType)
- return failure();
-
- if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
- rewriter.replaceOp(extractOp, adaptor.getVector());
- return success();
- }
-
- APInt cstPos;
- if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
- rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- extractOp, resultType, adaptor.getVector(),
- rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
- else
- rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
- extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
- return success();
- }
-};
-
-struct VectorInsertElementOpConvert final
- : public OpConversionPattern<vector::InsertElementOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Type vectorType = getTypeConverter()->convertType(insertOp.getType());
- if (!vectorType)
- return failure();
-
- if (isa<spirv::ScalarType>(vectorType)) {
- rewriter.replaceOp(insertOp, adaptor.getSource());
- return success();
- }
-
- APInt cstPos;
- if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
- rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
- insertOp, adaptor.getSource(), adaptor.getDest(),
- cstPos.getSExtValue());
- else
- rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
- insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
- adaptor.getPosition());
- return success();
- }
-};
-
struct VectorInsertStridedSliceOpConvert final
: public OpConversionPattern<vector::InsertStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
@@ -1107,12 +1050,11 @@ struct VectorToElementOpConvert final
void mlir::populateVectorToSPIRVPatterns(
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<
- VectorBitcastConvert, VectorBroadcastConvert,
- VectorExtractElementOpConvert, VectorExtractOpConvert,
+ VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
- VectorToElementOpConvert, VectorInsertElementOpConvert,
- VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+ VectorToElementOpConvert, VectorInsertOpConvert,
+ VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,