aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp10
-rw-r--r--mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp10
-rw-r--r--mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp6
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp6
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h21
-rw-r--r--mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp5
-rw-r--r--mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp21
-rw-r--r--mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp4
-rw-r--r--mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp26
-rw-r--r--mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp24
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp21
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp7
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp137
-rw-r--r--mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp14
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp29
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp11
-rw-r--r--mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp13
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp87
-rw-r--r--mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt13
-rw-r--r--mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp26
-rw-r--r--mlir/lib/Dialect/OpenACC/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp14
-rw-r--r--mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp28
-rw-r--r--mlir/lib/Dialect/OpenMP/CMakeLists.txt2
-rw-r--r--mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp170
-rw-r--r--mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt14
-rw-r--r--mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp445
-rw-r--r--mlir/lib/Dialect/SCF/IR/SCF.cpp104
-rw-r--r--mlir/lib/Dialect/Shard/Transforms/Partition.cpp16
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp94
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp37
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaOps.cpp218
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp15
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp19
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp6
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp7
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp2
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp27
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp31
-rw-r--r--mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt2
-rw-r--r--mlir/lib/IR/Diagnostics.cpp6
-rw-r--r--mlir/lib/RegisterAllPasses.cpp2
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp75
-rw-r--r--mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp8
-rw-r--r--mlir/lib/Target/SPIRV/Serialization/Serializer.cpp6
-rw-r--r--mlir/lib/Target/Wasm/TranslateFromWasm.cpp20
47 files changed, 1475 insertions, 386 deletions
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 85f0fd1d..9b15435 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1927,16 +1927,16 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
else
llvm_unreachable("unsupported row length");
- const Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
- const Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
+ Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
+ Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
- const Value isEqual =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, vdst0, v);
+ Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
+ LLVM::ICmpPredicate::eq, vdst0, v);
// Per `permlane(16|32)` semantics: if the first extracted element equals
// 'v', the result is the second element; otherwise it is the first.
Value vdstNew =
- rewriter.create<LLVM::SelectOp>(loc, isEqual, vdst1, vdst0);
+ LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
permuted.emplace_back(vdstNew);
}
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 42099aa..12adfe1 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -93,11 +93,11 @@ struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> {
Location loc = op.getLoc();
Value exponentReal =
- rewriter.create<arith::SIToFPOp>(loc, exponentFloatType, op.getRhs());
- Value zeroImag = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(exponentFloatType));
- Value exponent = rewriter.create<complex::CreateOp>(
- loc, op.getLhs().getType(), exponentReal, zeroImag);
+ arith::SIToFPOp::create(rewriter, loc, exponentFloatType, op.getRhs());
+ Value zeroImag = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getZeroAttr(exponentFloatType));
+ Value exponent = complex::CreateOp::create(
+ rewriter, loc, op.getLhs().getType(), exponentReal, zeroImag);
rewriter.replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
exponent, op.getFastmathAttr());
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 5613e02..0fe7239 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -937,14 +937,14 @@ struct PowiOpConversion : public OpConversionPattern<complex::PowiOp> {
auto elementType = cast<FloatType>(type.getElementType());
Value floatExponent =
- builder.create<arith::SIToFPOp>(elementType, adaptor.getRhs());
+ arith::SIToFPOp::create(builder, elementType, adaptor.getRhs());
Value zero = arith::ConstantOp::create(
builder, elementType, builder.getFloatAttr(elementType, 0.0));
Value complexExponent =
complex::CreateOp::create(builder, type, floatExponent, zero);
- auto pow = builder.create<complex::PowOp>(
- type, adaptor.getLhs(), complexExponent, op.getFastmathAttr());
+ auto pow = complex::PowOp::create(builder, type, adaptor.getLhs(),
+ complexExponent, op.getFastmathAttr());
rewriter.replaceOp(op, pow.getResult());
return success();
}
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 2285d26..eb662a1 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -507,7 +507,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
/*isVarArg=*/true);
LLVM::LLVMFuncOp printfDecl =
- getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
+ getOrDefineFunction(moduleOp, loc, rewriter, funcName, printfType);
+ printfDecl.setCConv(callingConvention);
// Create the global op or find an existing one.
LLVM::GlobalOp global = getOrCreateStringConstant(
@@ -530,7 +531,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
printfArgs.push_back(stringStart);
printfArgs.append(argsRange.begin(), argsRange.end());
- LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
+ auto call = LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
+ call.setCConv(callingConvention);
rewriter.eraseOp(gpuPrintfOp);
return success();
}
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 66d3bb4..ec74787 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -10,6 +10,7 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
namespace mlir {
@@ -142,13 +143,23 @@ struct GPUPrintfOpToHIPLowering : public ConvertOpToLLVMPattern<gpu::PrintfOp> {
/// This pass will add a declaration of printf() to the GPUModule if needed
/// and separate out the format strings into global constants. For some
/// runtimes, such as OpenCL on AMD, this is sufficient setup, as the compiler
-/// will lower printf calls to appropriate device-side code
+/// will lower printf calls to appropriate device-side code.
+/// However not all backends use the same calling convention and function
+/// naming.
+/// For example, the LLVM SPIRV backend requires calling convention
+/// LLVM::cconv::CConv::SPIR_FUNC and function name needs to be
+/// mangled as "_Z6printfPU3AS2Kcz".
+/// Default callingConvention is LLVM::cconv::CConv::C and
+/// funcName is "printf" but they can be customized as needed.
struct GPUPrintfOpToLLVMCallLowering
: public ConvertOpToLLVMPattern<gpu::PrintfOp> {
- GPUPrintfOpToLLVMCallLowering(const LLVMTypeConverter &converter,
- int addressSpace = 0)
+ GPUPrintfOpToLLVMCallLowering(
+ const LLVMTypeConverter &converter, int addressSpace = 0,
+ LLVM::cconv::CConv callingConvention = LLVM::cconv::CConv::C,
+ StringRef funcName = "printf")
: ConvertOpToLLVMPattern<gpu::PrintfOp>(converter),
- addressSpace(addressSpace) {}
+ addressSpace(addressSpace), callingConvention(callingConvention),
+ funcName(funcName) {}
LogicalResult
matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
@@ -156,6 +167,8 @@ struct GPUPrintfOpToLLVMCallLowering
private:
int addressSpace;
+ LLVM::cconv::CConv callingConvention;
+ StringRef funcName;
};
/// Lowering of gpu.printf to a vprintf standard library.
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index c2363a1..25f1e1b 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -470,10 +470,13 @@ struct GPUToLLVMSPVConversionPass final
gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
- gpu::ThreadIdOp>();
+ gpu::ThreadIdOp, gpu::PrintfOp>();
populateGpuToLLVMSPVConversionPatterns(converter, patterns);
populateGpuMemorySpaceAttributeConversions(converter);
+ patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/2,
+ LLVM::cconv::CConv::SPIR_FUNC,
+ "_Z6printfPU3AS2Kcz");
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 852c50c..d64c4d6 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -500,19 +500,19 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
assert(scope && "Expected op to be inside automatic allocation scope");
rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
- auto one = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
+ auto one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
+ rewriter.getI32IntegerAttr(1));
sinPtr =
- rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+ LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
cosPtr =
- rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+ LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
}
createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
op);
- auto sinResult = rewriter.create<LLVM::LoadOp>(loc, computeType, sinPtr);
- auto cosResult = rewriter.create<LLVM::LoadOp>(loc, computeType, cosPtr);
+ auto sinResult = LLVM::LoadOp::create(rewriter, loc, computeType, sinPtr);
+ auto cosResult = LLVM::LoadOp::create(rewriter, loc, computeType, cosPtr);
rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
maybeTrunc(cosResult, inputType, rewriter)});
@@ -522,14 +522,15 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
private:
Value maybeExt(Value operand, PatternRewriter &rewriter) const {
if (isa<Float16Type, BFloat16Type>(operand.getType()))
- return rewriter.create<LLVM::FPExtOp>(
- operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
+ return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
+ Float32Type::get(rewriter.getContext()),
+ operand);
return operand;
}
Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
if (operand.getType() != type)
- return rewriter.create<LLVM::FPTruncOp>(operand.getLoc(), type, operand);
+ return LLVM::FPTruncOp::create(rewriter, operand.getLoc(), type, operand);
return operand;
}
@@ -556,7 +557,7 @@ private:
}
SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
- rewriter.create<LLVM::CallOp>(loc, funcOp, callOperands);
+ LLVM::CallOp::create(rewriter, loc, funcOp, callOperands);
}
};
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 229e40e..7cce324 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -142,8 +142,8 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
auto structType = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), {llvmOperandType, llvmOperandType});
- auto sincosOp = rewriter.create<LLVM::SincosOp>(
- loc, structType, adaptor.getOperand(), attrs.getAttrs());
+ auto sincosOp = LLVM::SincosOp::create(
+ rewriter, loc, structType, adaptor.getOperand(), attrs.getAttrs());
auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 519d9c8..71e3f88 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -394,9 +394,9 @@ private:
if (!convertedType)
return rewriter.notifyMatchFailure(whileOp, "type conversion failed");
- emitc::VariableOp var = rewriter.create<emitc::VariableOp>(
- loc, emitc::LValueType::get(convertedType), noInit);
- rewriter.create<emitc::AssignOp>(loc, var.getResult(), init);
+ auto var = emitc::VariableOp::create(
+ rewriter, loc, emitc::LValueType::get(convertedType), noInit);
+ emitc::AssignOp::create(rewriter, loc, var.getResult(), init);
loopVars.push_back(var);
}
@@ -411,11 +411,11 @@ private:
// Create a global boolean variable to store the loop condition state.
Type i1Type = IntegerType::get(context, 1);
auto globalCondition =
- rewriter.create<emitc::VariableOp>(loc, emitc::LValueType::get(i1Type),
- emitc::OpaqueAttr::get(context, ""));
+ emitc::VariableOp::create(rewriter, loc, emitc::LValueType::get(i1Type),
+ emitc::OpaqueAttr::get(context, ""));
Value conditionVal = globalCondition.getResult();
- auto loweredDo = rewriter.create<emitc::DoOp>(loc);
+ auto loweredDo = emitc::DoOp::create(rewriter, loc);
// Convert region types to match the target dialect type system.
if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(),
@@ -450,12 +450,12 @@ private:
// Convert scf.condition to condition variable assignment.
Value condition = rewriter.getRemappedValue(condOp.getCondition());
- rewriter.create<emitc::AssignOp>(loc, conditionVal, condition);
+ emitc::AssignOp::create(rewriter, loc, conditionVal, condition);
// Wrap body region in conditional to preserve scf semantics. Only create
// ifOp if after-region is non-empty.
if (whileOp.getAfterBody()->getOperations().size() > 1) {
- auto ifOp = rewriter.create<emitc::IfOp>(loc, condition, false, false);
+ auto ifOp = emitc::IfOp::create(rewriter, loc, condition, false, false);
// Prepare the after region (loop body) for merging.
Block *afterBlock = &whileOp.getAfter().front();
@@ -480,8 +480,8 @@ private:
Block *condBlock = rewriter.createBlock(&condRegion);
rewriter.setInsertionPointToStart(condBlock);
- auto exprOp = rewriter.create<emitc::ExpressionOp>(
- loc, i1Type, conditionVal, /*do_not_inline=*/false);
+ auto exprOp = emitc::ExpressionOp::create(
+ rewriter, loc, i1Type, conditionVal, /*do_not_inline=*/false);
Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion());
// Set up the expression block to load the condition variable.
@@ -490,12 +490,12 @@ private:
// Load the condition value and yield it as the expression result.
Value cond =
- rewriter.create<emitc::LoadOp>(loc, i1Type, exprBlock->getArgument(0));
- rewriter.create<emitc::YieldOp>(loc, cond);
+ emitc::LoadOp::create(rewriter, loc, i1Type, exprBlock->getArgument(0));
+ emitc::YieldOp::create(rewriter, loc, cond);
// Yield the expression as the condition region result.
rewriter.setInsertionPointToEnd(condBlock);
- rewriter.create<emitc::YieldOp>(loc, exprOp);
+ emitc::YieldOp::create(rewriter, loc, exprOp);
return success();
}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 00df14b1..29afdc2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -232,16 +232,16 @@ static Value createLinalgBodyCalculationForElementwiseOp(
}
intermediateType = rewriter.getIntegerType(intermediateBitWidth);
- zpAddValue = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+ zpAddValue = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
} else {
intermediateType = rewriter.getIntegerType(intermediateBitWidth);
auto arg1 =
- rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[1]);
+ arith::ExtSIOp::create(rewriter, loc, intermediateType, args[1]);
auto arg2 =
- rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[2]);
+ arith::ExtSIOp::create(rewriter, loc, intermediateType, args[2]);
zpAddValue =
- rewriter.create<arith::AddIOp>(loc, intermediateType, arg1, arg2);
+ arith::AddIOp::create(rewriter, loc, intermediateType, arg1, arg2);
}
// The negation can be applied by doing:
@@ -1402,8 +1402,8 @@ static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
auto elemType = inputType.getElementType();
auto collapsedType = RankedTensorType::get({}, elemType);
// Emit the collapse op
- return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
- reassociation);
+ return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input,
+ reassociation);
}
static llvm::SmallVector<int8_t>
@@ -1443,7 +1443,7 @@ static void setupLinalgGenericOpInputAndIndexingMap(
IntegerAttr intAttr = isShift
? rewriter.getI8IntegerAttr(values.front())
: rewriter.getI32IntegerAttr(values.front());
- constant = rewriter.create<arith::ConstantOp>(loc, intAttr);
+ constant = arith::ConstantOp::create(rewriter, loc, intAttr);
} else {
auto elementType =
isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
@@ -1511,14 +1511,14 @@ static Value getExtendZp(OpBuilder &builder, Type valueTy,
.getResult(0);
}
if (zpTy.isUnsignedInteger()) {
- return builder.create<arith::ExtUIOp>(loc, extendType, result);
+ return arith::ExtUIOp::create(builder, loc, extendType, result);
} else {
- return builder.create<arith::ExtSIOp>(loc, extendType, result);
+ return arith::ExtSIOp::create(builder, loc, extendType, result);
}
}
} else {
- return builder.create<arith::ConstantOp>(
- loc, IntegerAttr::get(extendType, *maybeZp));
+ return arith::ConstantOp::create(builder, loc,
+ IntegerAttr::get(extendType, *maybeZp));
}
return result;
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
index 316721b..60ae78b 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
@@ -45,18 +45,15 @@ struct LoopUnroll : public affine::impl::AffineLoopUnrollBase<LoopUnroll> {
const std::function<unsigned(AffineForOp)> getUnrollFactor;
LoopUnroll() : getUnrollFactor(nullptr) {}
- LoopUnroll(const LoopUnroll &other)
-
- = default;
+ LoopUnroll(const LoopUnroll &other) = default;
explicit LoopUnroll(
std::optional<unsigned> unrollFactor = std::nullopt,
- bool unrollUpToFactor = false, bool unrollFull = false,
+ bool unrollUpToFactor = false,
const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr)
: getUnrollFactor(getUnrollFactor) {
if (unrollFactor)
this->unrollFactor = *unrollFactor;
this->unrollUpToFactor = unrollUpToFactor;
- this->unrollFull = unrollFull;
}
void runOnOperation() override;
@@ -85,11 +82,17 @@ static void gatherInnermostLoops(FunctionOpInterface f,
}
void LoopUnroll::runOnOperation() {
+ if (!(unrollFactor.getValue() > 0 || unrollFactor.getValue() == -1)) {
+ emitError(UnknownLoc::get(&getContext()),
+ "Invalid option: 'unroll-factor' should be greater than 0 or "
+ "equal to -1");
+ return signalPassFailure();
+ }
FunctionOpInterface func = getOperation();
if (func.isExternal())
return;
- if (unrollFull && unrollFullThreshold.hasValue()) {
+ if (unrollFactor.getValue() == -1 && unrollFullThreshold.hasValue()) {
// Store short loops as we walk.
SmallVector<AffineForOp, 4> loops;
@@ -130,7 +133,7 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
return loopUnrollByFactor(forOp, getUnrollFactor(forOp),
/*annotateFn=*/nullptr, cleanUpUnroll);
// Unroll completely if full loop unroll was specified.
- if (unrollFull)
+ if (unrollFactor.getValue() == -1)
return loopUnrollFull(forOp);
// Otherwise, unroll by the given unroll factor.
if (unrollUpToFactor)
@@ -141,9 +144,9 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
std::unique_ptr<InterfacePass<FunctionOpInterface>>
mlir::affine::createLoopUnrollPass(
- int unrollFactor, bool unrollUpToFactor, bool unrollFull,
+ int unrollFactor, bool unrollUpToFactor,
const std::function<unsigned(AffineForOp)> &getUnrollFactor) {
return std::make_unique<LoopUnroll>(
unrollFactor == -1 ? std::nullopt : std::optional<unsigned>(unrollFactor),
- unrollUpToFactor, unrollFull, getUnrollFactor);
+ unrollUpToFactor, getUnrollFactor);
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index a6159ee..f0ddb50 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -14,13 +14,6 @@
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
-namespace mlir {
-namespace bufferization {
-#define GEN_PASS_DEF_TENSORCOPYINSERTION
-#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
-} // namespace bufferization
-} // namespace mlir
-
using namespace mlir;
using namespace mlir::bufferization;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 2a8c330..f0de4db 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -320,6 +320,51 @@ LogicalResult ConvertF32x2ToF4x2Op::verify() {
return success();
}
+LogicalResult ConvertF8x2ToF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx)
+ << " types are supported for conversions from f8x2 to f16x2.";
+
+ return success();
+}
+
+LogicalResult ConvertF8x2ToBF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+ if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float8E8M0FNUType::get(ctx)
+ << " type is supported for conversions from f8x2 to bf16x2.";
+
+ return success();
+}
+
+LogicalResult ConvertF6x2ToF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float6E2M3FNType::get(ctx) << " and "
+ << mlir::Float6E3M2FNType::get(ctx)
+ << " types are supported for conversions from f6x2 to f16x2.";
+
+ return success();
+}
+
+LogicalResult ConvertF4x2ToF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from f4x2 to f16x2.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2187,6 +2232,98 @@ ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
+NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
+
+ bool hasRelu = curOp.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case<Float8E4M3FNType>([&](Float8E4M3FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
+ })
+ .Case<Float8E5M2Type>([&](Float8E5M2Type type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
+ })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
+
+ llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
+
+ bool hasRelu = curOp.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case<Float6E2M3FNType>([&](Float6E2M3FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
+ })
+ .Case<Float6E3M2FNType>([&](Float6E3M2FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
+ })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
+
+ bool hasRelu = curOp.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case<Float4E2M1FNType>([&](Float4E2M1FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
+ })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *extendedI16 =
+ builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {extendedI16}};
+}
+
llvm::Intrinsic::ID
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
LLVM::ModuleTranslation &mt,
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
index d4ff095..37a45d4 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
@@ -18,4 +18,5 @@ add_mlir_dialect_library(MLIRLLVMIRTransforms
MLIRPass
MLIRTransforms
MLIRNVVMDialect
+ MLIROpenMPDialect
)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9a8a63e..794dda9 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -437,13 +437,15 @@ transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
if (!ShapedType::isDynamic(dim))
continue;
- Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos);
- auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst);
+ Value cst =
+ arith::ConstantIndexOp::create(rewriter, tensor.getLoc(), pos);
+ auto dimOp =
+ tensor::DimOp::create(rewriter, tensor.getLoc(), tensor, cst);
preservedOps.insert(dimOp);
dynamicDims.push_back(dimOp);
}
- auto allocation = rewriter.create<bufferization::AllocTensorOp>(
- tensor.getLoc(), type, dynamicDims);
+ auto allocation = bufferization::AllocTensorOp::create(
+ rewriter, tensor.getLoc(), type, dynamicDims);
// Set memory space if provided.
if (getMemorySpaceAttr())
allocation.setMemorySpaceAttr(getMemorySpaceAttr());
@@ -452,8 +454,8 @@ transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
// Only insert a materialization (typically bufferizes to a copy) when the
// value may be read from.
if (needsMaterialization) {
- auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>(
- tensor.getLoc(), tensor, allocated);
+ auto copy = bufferization::MaterializeInDestinationOp::create(
+ rewriter, tensor.getLoc(), tensor, allocated);
preservedOps.insert(copy);
promoted.push_back(copy.getResult());
} else {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
index 15eb51a..5e10ba3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
@@ -43,6 +44,33 @@ struct StructuredOpInterface
auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
auto one = arith::ConstantIndexOp::create(builder, loc, 1);
+ Value iterationDomainIsNonDegenerate;
+ for (auto [start, end] : llvm::zip(starts, ends)) {
+ auto startValue = getValueOrCreateConstantIndexOp(builder, loc, start);
+ auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
+
+ // Loop Trip count > 0 iff start < end
+ Value dimensionHasNonZeroTripCount = index::CmpOp::create(
+ builder, loc, index::IndexCmpPredicate::SLT, startValue, endValue);
+
+ if (!iterationDomainIsNonDegenerate) {
+ iterationDomainIsNonDegenerate = dimensionHasNonZeroTripCount;
+ } else {
+ // Iteration domain is non-degenerate iff all dimensions have loop trip
+ // count > 0
+ iterationDomainIsNonDegenerate =
+ arith::AndIOp::create(builder, loc, iterationDomainIsNonDegenerate,
+ dimensionHasNonZeroTripCount);
+ }
+ }
+
+ if (!iterationDomainIsNonDegenerate)
+ return;
+
+ auto ifOp = scf::IfOp::create(builder, loc, iterationDomainIsNonDegenerate,
+ /*withElseRegion=*/false);
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
// Subtract one from the loop ends before composing with the indexing map
transform(ends, ends.begin(), [&](OpFoldResult end) {
auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
@@ -110,6 +138,7 @@ struct StructuredOpInterface
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
}
}
+ builder.setInsertionPointAfter(ifOp);
}
};
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 94947b7..c551fba 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1437,6 +1437,13 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
atLeastOneReplacement |= replaceConstantUsesOf(
builder, getLoc(), getStrides(), getConstifiedMixedStrides());
+ // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
+ if (auto prev = getSource().getDefiningOp<CastOp>())
+ if (isa<MemRefType>(prev.getSource().getType())) {
+ getSourceMutable().assign(prev.getSource());
+ atLeastOneReplacement = true;
+ }
+
return success(atLeastOneReplacement);
}
@@ -1744,11 +1751,11 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
}
TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
- return cast<TypedValue<PtrLikeTypeInterface>>(getSource());
+ return getSource();
}
TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
- return cast<TypedValue<PtrLikeTypeInterface>>(getDest());
+ return getDest();
}
bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index 11400de..a15bf89 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -59,6 +59,17 @@ struct DimOpInterface
}
};
+struct ExpandShapeOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<ExpandShapeOpInterface,
+ memref::ExpandShapeOp> {
+ void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+ ValueBoundsConstraintSet &cstr) const {
+ auto expandOp = cast<memref::ExpandShapeOp>(op);
+ assert(value == expandOp.getResult() && "invalid value");
+ cstr.bound(value)[dim] == expandOp.getOutputShape()[dim];
+ }
+};
+
struct GetGlobalOpInterface
: public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
GetGlobalOp> {
@@ -123,6 +134,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
+ memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>(
+ *ctx);
memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index d35566a..bd02516 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -1033,91 +1033,6 @@ class ExtractStridedMetadataOpReinterpretCastFolder
}
};
-/// Replace `base, offset, sizes, strides =
-/// extract_strided_metadata(
-/// cast(src) to dstTy)`
-/// With
-/// ```
-/// base, ... = extract_strided_metadata(src)
-/// offset = !dstTy.srcOffset.isDynamic()
-/// ? dstTy.srcOffset
-/// : extract_strided_metadata(src).offset
-/// sizes = for each srcSize in dstTy.srcSizes:
-/// !srcSize.isDynamic()
-/// ? srcSize
-// : extract_strided_metadata(src).sizes[i]
-/// strides = for each srcStride in dstTy.srcStrides:
-/// !srcStrides.isDynamic()
-/// ? srcStrides
-/// : extract_strided_metadata(src).strides[i]
-/// ```
-///
-/// In other words, consume the `cast` and apply its effects
-/// on the offset, sizes, and strides or compute them directly from `src`.
-class ExtractStridedMetadataOpCastFolder
- : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult
- matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
- PatternRewriter &rewriter) const override {
- Value source = extractStridedMetadataOp.getSource();
- auto castOp = source.getDefiningOp<memref::CastOp>();
- if (!castOp)
- return failure();
-
- Location loc = extractStridedMetadataOp.getLoc();
- // Check if the source is suitable for extract_strided_metadata.
- SmallVector<Type> inferredReturnTypes;
- if (failed(extractStridedMetadataOp.inferReturnTypes(
- rewriter.getContext(), loc, {castOp.getSource()},
- /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
- inferredReturnTypes)))
- return rewriter.notifyMatchFailure(castOp,
- "cast source's type is incompatible");
-
- auto memrefType = cast<MemRefType>(source.getType());
- unsigned rank = memrefType.getRank();
- SmallVector<OpFoldResult> results;
- results.resize_for_overwrite(rank * 2 + 2);
-
- auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
- rewriter, loc, castOp.getSource());
-
- // Register the base_buffer.
- results[0] = newExtractStridedMetadata.getBaseBuffer();
-
- auto getConstantOrValue = [&rewriter](int64_t constant,
- OpFoldResult ofr) -> OpFoldResult {
- return ShapedType::isStatic(constant)
- ? OpFoldResult(rewriter.getIndexAttr(constant))
- : ofr;
- };
-
- auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
- assert(sourceStrides.size() == rank && "unexpected number of strides");
-
- // Register the new offset.
- results[1] =
- getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
-
- const unsigned sizeStartIdx = 2;
- const unsigned strideStartIdx = sizeStartIdx + rank;
- ArrayRef<int64_t> sourceSizes = memrefType.getShape();
-
- SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
- SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
- for (unsigned i = 0; i < rank; ++i) {
- results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
- results[strideStartIdx + i] =
- getConstantOrValue(sourceStrides[i], strides[i]);
- }
- rewriter.replaceOp(extractStridedMetadataOp,
- getValueOrCreateConstantIndexOp(rewriter, loc, results));
- return success();
- }
-};
-
/// Replace `base, offset, sizes, strides = extract_strided_metadata(
/// memory_space_cast(src) to dstTy)`
/// with
@@ -1209,7 +1124,6 @@ void memref::populateExpandStridedMetadataPatterns(
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpSubviewFolder,
- ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
@@ -1226,7 +1140,6 @@ void memref::populateResolveExtractStridedMetadataPatterns(
ExtractStridedMetadataOpSubviewFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
- ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
diff --git a/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt
new file mode 100644
index 0000000..f305068
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIROpenACCAnalysis
+ OpenACCSupport.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIROpenACCDialect
+ MLIROpenACCUtils
+ MLIRSupport
+)
+
diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
new file mode 100644
index 0000000..f6b4534
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
@@ -0,0 +1,26 @@
+//===- OpenACCSupport.cpp - OpenACCSupport Implementation -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the OpenACCSupport analysis interface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
+
+namespace mlir {
+namespace acc {
+
+std::string OpenACCSupport::getVariableName(Value v) {
+ if (impl)
+ return impl->getVariableName(v);
+ return acc::getVariableName(v);
+}
+
+} // namespace acc
+} // namespace mlir
diff --git a/mlir/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/CMakeLists.txt
index 7117520..e8a916e 100644
--- a/mlir/lib/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(Analysis)
add_subdirectory(IR)
add_subdirectory(Utils)
add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 5ca0100..ca46629 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -610,6 +610,20 @@ LogicalResult acc::FirstprivateOp::verify() {
}
//===----------------------------------------------------------------------===//
+// FirstprivateMapInitialOp
+//===----------------------------------------------------------------------===//
+LogicalResult acc::FirstprivateMapInitialOp::verify() {
+ if (getDataClause() != acc::DataClause::acc_firstprivate)
+ return emitError("data clause associated with firstprivate operation must "
+ "match its intent");
+ if (failed(checkVarAndVarType(*this)))
+ return failure();
+ if (failed(checkNoModifier(*this)))
+ return failure();
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//
LogicalResult acc::ReductionOp::verify() {
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index 1223325..89adda82 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/TypeSwitch.h"
mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region &region) {
@@ -78,3 +79,30 @@ mlir::acc::VariableTypeCategory mlir::acc::getTypeCategory(mlir::Value var) {
pointerLikeTy.getElementType());
return typeCategory;
}
+
+std::string mlir::acc::getVariableName(mlir::Value v) {
+ Value current = v;
+
+ // Walk through view operations until a name is found or can't go further
+ while (Operation *definingOp = current.getDefiningOp()) {
+ // Check for `acc.var_name` attribute
+ if (auto varNameAttr =
+ definingOp->getAttrOfType<VarNameAttr>(getVarNameAttrName()))
+ return varNameAttr.getName().str();
+
+ // If it is a data entry operation, get name via getVarName
+ if (isa<ACC_DATA_ENTRY_OPS>(definingOp))
+ if (auto name = acc::getVarName(definingOp))
+ return name->str();
+
+ // If it's a view operation, continue to the source
+ if (auto viewOp = dyn_cast<ViewLikeOpInterface>(definingOp)) {
+ current = viewOp.getViewSource();
+ continue;
+ }
+
+ break;
+ }
+
+ return "";
+}
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
index 57a6d34..f3c02da 100644
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -1,3 +1,5 @@
+add_subdirectory(Transforms)
+
add_mlir_dialect_library(MLIROpenMPDialect
IR/OpenMPDialect.cpp
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index fd4cabbad..1b069c6 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -32,7 +32,6 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
-#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Support/InterleavedRange.h"
#include <cstddef>
#include <iterator>
@@ -1737,10 +1736,10 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
// Parser, printer and verifier for Target
//===----------------------------------------------------------------------===//
-// Helper function to get bitwise AND of `value` and 'flag'
-static uint64_t mapTypeToBitFlag(uint64_t value,
- llvm::omp::OpenMPOffloadMappingFlags flag) {
- return value & llvm::to_underlying(flag);
+// Helper function to get bitwise AND of `value` and 'flag' then return it as a
+// boolean
+static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag) {
+ return (value & flag) == flag;
}
/// Parses a map_entries map type from a string format back into its numeric
@@ -1748,10 +1747,9 @@ static uint64_t mapTypeToBitFlag(uint64_t value,
///
/// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
/// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
-static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
- llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
-
+static ParseResult parseMapClause(OpAsmParser &parser,
+ ClauseMapFlagsAttr &mapType) {
+ ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
// This simply verifies the correct keyword is read in, the
// keyword itself is stored inside of the operation
auto parseTypeAndMod = [&]() -> ParseResult {
@@ -1760,35 +1758,64 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
return failure();
if (mapTypeMod == "always")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
+ mapTypeBits |= ClauseMapFlags::always;
if (mapTypeMod == "implicit")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
+ mapTypeBits |= ClauseMapFlags::implicit;
if (mapTypeMod == "ompx_hold")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
+ mapTypeBits |= ClauseMapFlags::ompx_hold;
if (mapTypeMod == "close")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
+ mapTypeBits |= ClauseMapFlags::close;
if (mapTypeMod == "present")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
+ mapTypeBits |= ClauseMapFlags::present;
if (mapTypeMod == "to")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
+ mapTypeBits |= ClauseMapFlags::to;
if (mapTypeMod == "from")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+ mapTypeBits |= ClauseMapFlags::from;
if (mapTypeMod == "tofrom")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+ mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
if (mapTypeMod == "delete")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
+ mapTypeBits |= ClauseMapFlags::del;
+
+ if (mapTypeMod == "storage")
+ mapTypeBits |= ClauseMapFlags::storage;
if (mapTypeMod == "return_param")
- mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
+ mapTypeBits |= ClauseMapFlags::return_param;
+
+ if (mapTypeMod == "private")
+ mapTypeBits |= ClauseMapFlags::priv;
+
+ if (mapTypeMod == "literal")
+ mapTypeBits |= ClauseMapFlags::literal;
+
+ if (mapTypeMod == "attach")
+ mapTypeBits |= ClauseMapFlags::attach;
+
+ if (mapTypeMod == "attach_always")
+ mapTypeBits |= ClauseMapFlags::attach_always;
+
+ if (mapTypeMod == "attach_none")
+ mapTypeBits |= ClauseMapFlags::attach_none;
+
+ if (mapTypeMod == "attach_auto")
+ mapTypeBits |= ClauseMapFlags::attach_auto;
+
+ if (mapTypeMod == "ref_ptr")
+ mapTypeBits |= ClauseMapFlags::ref_ptr;
+
+ if (mapTypeMod == "ref_ptee")
+ mapTypeBits |= ClauseMapFlags::ref_ptee;
+
+ if (mapTypeMod == "ref_ptr_ptee")
+ mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
return success();
};
@@ -1796,9 +1823,8 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
if (parser.parseCommaSeparatedList(parseTypeAndMod))
return failure();
- mapType = parser.getBuilder().getIntegerAttr(
- parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
- llvm::to_underlying(mapTypeBits));
+ mapType =
+ parser.getBuilder().getAttr<mlir::omp::ClauseMapFlagsAttr>(mapTypeBits);
return success();
}
@@ -1806,60 +1832,62 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
/// Prints a map_entries map type from its numeric value out into its string
/// format.
static void printMapClause(OpAsmPrinter &p, Operation *op,
- IntegerAttr mapType) {
- uint64_t mapTypeBits = mapType.getUInt();
-
- bool emitAllocRelease = true;
+ ClauseMapFlagsAttr mapType) {
llvm::SmallVector<std::string, 4> mapTypeStrs;
+ ClauseMapFlags mapFlags = mapType.getValue();
// handling of always, close, present placed at the beginning of the string
// to aid readability
- if (mapTypeToBitFlag(mapTypeBits,
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::always))
mapTypeStrs.push_back("always");
- if (mapTypeToBitFlag(mapTypeBits,
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::implicit))
mapTypeStrs.push_back("implicit");
- if (mapTypeToBitFlag(mapTypeBits,
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD))
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::ompx_hold))
mapTypeStrs.push_back("ompx_hold");
- if (mapTypeToBitFlag(mapTypeBits,
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::close))
mapTypeStrs.push_back("close");
- if (mapTypeToBitFlag(mapTypeBits,
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::present))
mapTypeStrs.push_back("present");
// special handling of to/from/tofrom/delete and release/alloc, release +
// alloc are the abscense of one of the other flags, whereas tofrom requires
// both the to and from flag to be set.
- bool to = mapTypeToBitFlag(mapTypeBits,
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
- bool from = mapTypeToBitFlag(
- mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
- if (to && from) {
- emitAllocRelease = false;
+ bool to = mapTypeToBool(mapFlags, ClauseMapFlags::to);
+ bool from = mapTypeToBool(mapFlags, ClauseMapFlags::from);
+
+ if (to && from)
mapTypeStrs.push_back("tofrom");
- } else if (from) {
- emitAllocRelease = false;
+ else if (from)
mapTypeStrs.push_back("from");
- } else if (to) {
- emitAllocRelease = false;
+ else if (to)
mapTypeStrs.push_back("to");
- }
- if (mapTypeToBitFlag(mapTypeBits,
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
- emitAllocRelease = false;
+
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::del))
mapTypeStrs.push_back("delete");
- }
- if (mapTypeToBitFlag(
- mapTypeBits,
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
- emitAllocRelease = false;
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::return_param))
mapTypeStrs.push_back("return_param");
- }
- if (emitAllocRelease)
- mapTypeStrs.push_back("exit_release_or_enter_alloc");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::storage))
+ mapTypeStrs.push_back("storage");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::priv))
+ mapTypeStrs.push_back("private");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::literal))
+ mapTypeStrs.push_back("literal");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::attach))
+ mapTypeStrs.push_back("attach");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_always))
+ mapTypeStrs.push_back("attach_always");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_none))
+ mapTypeStrs.push_back("attach_none");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_auto))
+ mapTypeStrs.push_back("attach_auto");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr))
+ mapTypeStrs.push_back("ref_ptr");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptee))
+ mapTypeStrs.push_back("ref_ptee");
+ if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee))
+ mapTypeStrs.push_back("ref_ptr_ptee");
+ if (mapFlags == ClauseMapFlags::none)
+ mapTypeStrs.push_back("none");
for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
p << mapTypeStrs[i];
@@ -1963,21 +1991,15 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
return emitError(op->getLoc(), "missing map operation");
if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
- uint64_t mapTypeBits = mapInfoOp.getMapType();
-
- bool to = mapTypeToBitFlag(
- mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
- bool from = mapTypeToBitFlag(
- mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
- bool del = mapTypeToBitFlag(
- mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
-
- bool always = mapTypeToBitFlag(
- mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
- bool close = mapTypeToBitFlag(
- mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
- bool implicit = mapTypeToBitFlag(
- mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
+ mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
+
+ bool to = mapTypeToBool(mapTypeBits, ClauseMapFlags::to);
+ bool from = mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
+ bool del = mapTypeToBool(mapTypeBits, ClauseMapFlags::del);
+
+ bool always = mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
+ bool close = mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
+ bool implicit = mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
return emitError(op->getLoc(),
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..b9b8eda
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIROpenMPTransforms
+ OpenMPOffloadPrivatizationPrepare.cpp
+
+ DEPENDS
+ MLIROpenMPPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRFuncDialect
+ MLIRLLVMDialect
+ MLIROpenMPDialect
+ MLIRPass
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp
new file mode 100644
index 0000000..a9125ec
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp
@@ -0,0 +1,445 @@
+//===- OpenMPOffloadPrivatizationPrepare.cpp - Prepare OMP privatization --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <cstdint>
+#include <iterator>
+#include <utility>
+
+//===----------------------------------------------------------------------===//
+// A pass that prepares OpenMP code for translation of delayed privatization
+// in the context of deferred target tasks. Deferred target tasks are created
+// when the nowait clause is used on the target directive.
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "omp-prepare-for-offload-privatization"
+
+namespace mlir {
+namespace omp {
+
+#define GEN_PASS_DEF_PREPAREFOROMPOFFLOADPRIVATIZATIONPASS
+#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
+
+} // namespace omp
+} // namespace mlir
+
+using namespace mlir;
+namespace {
+
+//===----------------------------------------------------------------------===//
+// PrepareForOMPOffloadPrivatizationPass
+//===----------------------------------------------------------------------===//
+
+class PrepareForOMPOffloadPrivatizationPass
+ : public omp::impl::PrepareForOMPOffloadPrivatizationPassBase<
+ PrepareForOMPOffloadPrivatizationPass> {
+
+ void runOnOperation() override {
+ ModuleOp mod = getOperation();
+
+ // In this pass, we make host-allocated privatized variables persist for
+ // deferred target tasks by copying them to the heap. Once the target task
+ // is done, this heap memory is freed. Since all of this happens on the host
+ // we can skip device modules.
+ auto offloadModuleInterface =
+ dyn_cast<omp::OffloadModuleInterface>(mod.getOperation());
+ if (offloadModuleInterface && offloadModuleInterface.getIsTargetDevice())
+ return;
+
+ getOperation()->walk([&](omp::TargetOp targetOp) {
+ if (!hasPrivateVars(targetOp) || !isTargetTaskDeferred(targetOp))
+ return;
+ IRRewriter rewriter(&getContext());
+ OperandRange privateVars = targetOp.getPrivateVars();
+ SmallVector<mlir::Value> newPrivVars;
+ Value fakeDependVar;
+ omp::TaskOp cleanupTaskOp;
+
+ newPrivVars.reserve(privateVars.size());
+ std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
+ for (auto [privVarIdx, privVarSymPair] :
+ llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
+ Value privVar = std::get<0>(privVarSymPair);
+ Attribute privSym = std::get<1>(privVarSymPair);
+
+ omp::PrivateClauseOp privatizer = findPrivatizer(targetOp, privSym);
+ if (!privatizer.needsMap()) {
+ newPrivVars.push_back(privVar);
+ continue;
+ }
+ bool isFirstPrivate = privatizer.getDataSharingType() ==
+ omp::DataSharingClauseType::FirstPrivate;
+
+ Value mappedValue = targetOp.getMappedValueForPrivateVar(privVarIdx);
+ auto mapInfoOp = cast<omp::MapInfoOp>(mappedValue.getDefiningOp());
+
+ if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy) {
+ newPrivVars.push_back(privVar);
+ continue;
+ }
+
+ // For deferred target tasks (!$omp target nowait), we need to keep
+ // a copy of the original, i.e. host variable being privatized so
+ // that it is available when the target task is eventually executed.
+ // We do this by first allocating as much heap memory as is needed by
+ // the original variable. Then, we use the init and copy regions of the
+ // privatizer, an instance of omp::PrivateClauseOp to set up the heap-
+ // allocated copy.
+ // After the target task is done, we need to use the dealloc region
+ // of the privatizer to clean up everything. We also need to free
+ // the heap memory we allocated. But due to the deferred nature
+ // of the target task, we cannot simply deallocate right after the
+ // omp.target operation else we may end up freeing memory before
+ // its eventual use by the target task. So, we create a dummy
+ // dependence between the target task and new omp.task. In the omp.task,
+ // we do all the cleanup. So, we end up with the following structure
+ //
+ // omp.target map_entries(..) ... nowait depend(out:fakeDependVar) {
+ // ...
+ // omp.terminator
+ // }
+ // omp.task depend(in: fakeDependVar) {
+ // /*cleanup_code*/
+ // omp.terminator
+ // }
+ // fakeDependVar is the address of the first heap-allocated copy of the
+ // host variable being privatized.
+
+ bool needsCleanupTask = !privatizer.getDeallocRegion().empty();
+
+ // Allocate heap memory that corresponds to the type of memory
+ // pointed to by varPtr
+ // For boxchars this won't be a pointer. But, MapsForPrivatizedSymbols
+ // should have mapped the pointer to the boxchar so use that as varPtr.
+ Value varPtr = mapInfoOp.getVarPtr();
+ Type varType = mapInfoOp.getVarType();
+ bool isPrivatizedByValue =
+ !isa<LLVM::LLVMPointerType>(privVar.getType());
+
+ assert(isa<LLVM::LLVMPointerType>(varPtr.getType()));
+ Value heapMem =
+ allocateHeapMem(targetOp, varPtr, varType, mod, rewriter);
+ if (!heapMem)
+ targetOp.emitError(
+ "Unable to allocate heap memory when trying to move "
+ "a private variable out of the stack and into the "
+ "heap for use by a deferred target task");
+
+ if (needsCleanupTask && !fakeDependVar)
+ fakeDependVar = heapMem;
+
+ // The types of private vars should match before and after the
+ // transformation. In particular, if the type is a pointer,
+ // simply record the newly allocated malloc location as the
+ // new private variable. If, however, the type is not a pointer
+ // then, we need to load the value from the newly allocated
+ // location. We'll insert that load later after we have updated
+ // the malloc'd location with the contents of the original
+ // variable.
+ if (!isPrivatizedByValue)
+ newPrivVars.push_back(heapMem);
+
+ // We now need to copy the original private variable into the newly
+ // allocated location in the heap.
+ // Find the earliest insertion point for the copy. This will be before
+ // the first in the list of omp::MapInfoOp instances that use varPtr.
+ // After the copy these omp::MapInfoOp instances will refer to heapMem
+ // instead.
+ Operation *varPtrDefiningOp = varPtr.getDefiningOp();
+ DenseSet<Operation *> users;
+ if (varPtrDefiningOp) {
+ users.insert(varPtrDefiningOp->user_begin(),
+ varPtrDefiningOp->user_end());
+ } else {
+ auto blockArg = cast<BlockArgument>(varPtr);
+ users.insert(blockArg.user_begin(), blockArg.user_end());
+ }
+ auto usesVarPtr = [&users](Operation *op) -> bool {
+ return users.count(op);
+ };
+
+ SmallVector<Operation *> chainOfOps;
+ chainOfOps.push_back(mapInfoOp);
+ for (auto member : mapInfoOp.getMembers()) {
+ omp::MapInfoOp memberMap =
+ cast<omp::MapInfoOp>(member.getDefiningOp());
+ if (usesVarPtr(memberMap))
+ chainOfOps.push_back(memberMap);
+ if (memberMap.getVarPtrPtr()) {
+ Operation *defOp = memberMap.getVarPtrPtr().getDefiningOp();
+ if (defOp && usesVarPtr(defOp))
+ chainOfOps.push_back(defOp);
+ }
+ }
+
+ DominanceInfo dom;
+ llvm::sort(chainOfOps, [&](Operation *l, Operation *r) {
+ return dom.dominates(l, r);
+ });
+
+ rewriter.setInsertionPoint(chainOfOps.front());
+
+ Operation *firstOp = chainOfOps.front();
+ Location loc = firstOp->getLoc();
+
+ // Create a llvm.func for 'region' that is marked always_inline and call
+ // it.
+ auto createAlwaysInlineFuncAndCallIt =
+ [&](Region &region, llvm::StringRef funcName,
+ llvm::ArrayRef<Value> args, bool returnsValue) -> Value {
+ assert(!region.empty() && "region cannot be empty");
+ LLVM::LLVMFuncOp func = createFuncOpForRegion(
+ loc, mod, region, funcName, rewriter, returnsValue);
+ auto call = LLVM::CallOp::create(rewriter, loc, func, args);
+ return call.getResult();
+ };
+
+ Value moldArg, newArg;
+ if (isPrivatizedByValue) {
+ moldArg = LLVM::LoadOp::create(rewriter, loc, varType, varPtr);
+ newArg = LLVM::LoadOp::create(rewriter, loc, varType, heapMem);
+ } else {
+ moldArg = varPtr;
+ newArg = heapMem;
+ }
+
+ Value initializedVal;
+ if (!privatizer.getInitRegion().empty())
+ initializedVal = createAlwaysInlineFuncAndCallIt(
+ privatizer.getInitRegion(),
+ llvm::formatv("{0}_{1}", privatizer.getSymName(), "init").str(),
+ {moldArg, newArg}, /*returnsValue=*/true);
+ else
+ initializedVal = newArg;
+
+ if (isFirstPrivate && !privatizer.getCopyRegion().empty())
+ initializedVal = createAlwaysInlineFuncAndCallIt(
+ privatizer.getCopyRegion(),
+ llvm::formatv("{0}_{1}", privatizer.getSymName(), "copy").str(),
+ {moldArg, initializedVal}, /*returnsValue=*/true);
+
+ if (isPrivatizedByValue)
+ (void)LLVM::StoreOp::create(rewriter, loc, initializedVal, heapMem);
+
+ // clone origOp, replace all uses of varPtr with heapMem and
+ // erase origOp.
+ auto cloneModifyAndErase = [&](Operation *origOp) -> Operation * {
+ Operation *clonedOp = rewriter.clone(*origOp);
+ rewriter.replaceAllOpUsesWith(origOp, clonedOp);
+ rewriter.modifyOpInPlace(clonedOp, [&]() {
+ clonedOp->replaceUsesOfWith(varPtr, heapMem);
+ });
+ rewriter.eraseOp(origOp);
+ return clonedOp;
+ };
+
+ // Now that we have set up the heap-allocated copy of the private
+ // variable, rewrite all the uses of the original variable with
+ // the heap-allocated variable.
+ rewriter.setInsertionPoint(targetOp);
+ mapInfoOp = cast<omp::MapInfoOp>(cloneModifyAndErase(mapInfoOp));
+ rewriter.setInsertionPoint(mapInfoOp);
+
+ // Fix any members that may use varPtr to now use heapMem
+ for (auto member : mapInfoOp.getMembers()) {
+ auto memberMapInfoOp = cast<omp::MapInfoOp>(member.getDefiningOp());
+ if (!usesVarPtr(memberMapInfoOp))
+ continue;
+ memberMapInfoOp =
+ cast<omp::MapInfoOp>(cloneModifyAndErase(memberMapInfoOp));
+ rewriter.setInsertionPoint(memberMapInfoOp);
+
+ if (memberMapInfoOp.getVarPtrPtr()) {
+ Operation *varPtrPtrdefOp =
+ memberMapInfoOp.getVarPtrPtr().getDefiningOp();
+ rewriter.setInsertionPoint(cloneModifyAndErase(varPtrPtrdefOp));
+ }
+ }
+
+ // If the type of the private variable is not a pointer,
+ // which is typically the case with !fir.boxchar types, then
+ // we need to ensure that the new private variable is also
+ // not a pointer. Insert a load from heapMem right before
+ // targetOp.
+ if (isPrivatizedByValue) {
+ rewriter.setInsertionPoint(targetOp);
+ auto newPrivVar = LLVM::LoadOp::create(rewriter, mapInfoOp.getLoc(),
+ varType, heapMem);
+ newPrivVars.push_back(newPrivVar);
+ }
+
+ // Deallocate
+ if (needsCleanupTask) {
+ if (!cleanupTaskOp) {
+ assert(fakeDependVar &&
+ "Need a valid value to set up a dependency");
+ rewriter.setInsertionPointAfter(targetOp);
+ omp::TaskOperands taskOperands;
+ auto inDepend = omp::ClauseTaskDependAttr::get(
+ rewriter.getContext(), omp::ClauseTaskDepend::taskdependin);
+ taskOperands.dependKinds.push_back(inDepend);
+ taskOperands.dependVars.push_back(fakeDependVar);
+ cleanupTaskOp = omp::TaskOp::create(rewriter, loc, taskOperands);
+ Block *taskBlock = rewriter.createBlock(&cleanupTaskOp.getRegion());
+ rewriter.setInsertionPointToEnd(taskBlock);
+ omp::TerminatorOp::create(rewriter, cleanupTaskOp.getLoc());
+ }
+ rewriter.setInsertionPointToStart(
+ &*cleanupTaskOp.getRegion().getBlocks().begin());
+ (void)createAlwaysInlineFuncAndCallIt(
+ privatizer.getDeallocRegion(),
+ llvm::formatv("{0}_{1}", privatizer.getSymName(), "dealloc")
+ .str(),
+ {initializedVal}, /*returnsValue=*/false);
+ llvm::FailureOr<LLVM::LLVMFuncOp> freeFunc =
+ LLVM::lookupOrCreateFreeFn(rewriter, mod);
+ assert(llvm::succeeded(freeFunc) &&
+ "Could not find free in the module");
+ (void)LLVM::CallOp::create(rewriter, loc, freeFunc.value(),
+ ValueRange{heapMem});
+ }
+ }
+ assert(newPrivVars.size() == privateVars.size() &&
+ "The number of private variables must match before and after "
+ "transformation");
+ if (fakeDependVar) {
+ omp::ClauseTaskDependAttr outDepend = omp::ClauseTaskDependAttr::get(
+ rewriter.getContext(), omp::ClauseTaskDepend::taskdependout);
+ SmallVector<Attribute> newDependKinds;
+ if (!targetOp.getDependVars().empty()) {
+ std::optional<ArrayAttr> dependKinds = targetOp.getDependKinds();
+ assert(dependKinds && "bad depend clause in omp::TargetOp");
+ llvm::copy(*dependKinds, std::back_inserter(newDependKinds));
+ }
+ newDependKinds.push_back(outDepend);
+ ArrayAttr newDependKindsAttr =
+ ArrayAttr::get(rewriter.getContext(), newDependKinds);
+ targetOp.getDependVarsMutable().append(fakeDependVar);
+ targetOp.setDependKindsAttr(newDependKindsAttr);
+ }
+ rewriter.setInsertionPoint(targetOp);
+ targetOp.getPrivateVarsMutable().clear();
+ targetOp.getPrivateVarsMutable().assign(newPrivVars);
+ });
+ }
+
+private:
+ bool hasPrivateVars(omp::TargetOp targetOp) const {
+ return !targetOp.getPrivateVars().empty();
+ }
+
+ bool isTargetTaskDeferred(omp::TargetOp targetOp) const {
+ return targetOp.getNowait();
+ }
+
+ template <typename OpTy>
+ omp::PrivateClauseOp findPrivatizer(OpTy op, Attribute privSym) const {
+ SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
+ omp::PrivateClauseOp privatizer =
+ SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
+ op, privatizerName);
+ return privatizer;
+ }
+
+ // Get the (compile-time constant) size of varType as per the
+ // given DataLayout dl.
+ std::int64_t getSizeInBytes(const DataLayout &dl, Type varType) const {
+ llvm::TypeSize size = dl.getTypeSize(varType);
+ unsigned short alignment = dl.getTypeABIAlignment(varType);
+ return llvm::alignTo(size, alignment);
+ }
+
+ LLVM::LLVMFuncOp getMalloc(ModuleOp mod, IRRewriter &rewriter) const {
+ llvm::FailureOr<LLVM::LLVMFuncOp> mallocCall =
+ LLVM::lookupOrCreateMallocFn(rewriter, mod, rewriter.getI64Type());
+ assert(llvm::succeeded(mallocCall) &&
+ "Could not find malloc in the module");
+ return mallocCall.value();
+ }
+
+ Value allocateHeapMem(omp::TargetOp targetOp, Value privVar, Type varType,
+ ModuleOp mod, IRRewriter &rewriter) const {
+ OpBuilder::InsertionGuard guard(rewriter);
+ Value varPtr = privVar;
+ Operation *definingOp = varPtr.getDefiningOp();
+ BlockArgument blockArg;
+ if (!definingOp) {
+ blockArg = mlir::dyn_cast<BlockArgument>(varPtr);
+ rewriter.setInsertionPointToStart(blockArg.getParentBlock());
+ } else {
+ rewriter.setInsertionPoint(definingOp);
+ }
+ Location loc = definingOp ? definingOp->getLoc() : blockArg.getLoc();
+ LLVM::LLVMFuncOp mallocFn = getMalloc(mod, rewriter);
+
+ assert(mod.getDataLayoutSpec() &&
+ "MLIR module with no datalayout spec not handled yet");
+
+ const DataLayout &dl = DataLayout(mod);
+ std::int64_t distance = getSizeInBytes(dl, varType);
+
+ Value sizeBytes = LLVM::ConstantOp::create(
+ rewriter, loc, mallocFn.getFunctionType().getParamType(0), distance);
+
+ auto mallocCallOp =
+ LLVM::CallOp::create(rewriter, loc, mallocFn, ValueRange{sizeBytes});
+ return mallocCallOp.getResult();
+ }
+
+ // Create a function for srcRegion and attribute it to be always_inline.
+ // The big assumption here is that srcRegion is one of init, copy or dealloc
+ // regions of a omp::PrivateClauseop. Accordingly, the return type is assumed
+ // to either be the same as the types of the two arguments of the region (for
+ // init and copy regions) or void as would be the case for dealloc regions.
+ LLVM::LLVMFuncOp createFuncOpForRegion(Location loc, ModuleOp mod,
+ Region &srcRegion,
+ llvm::StringRef funcName,
+ IRRewriter &rewriter,
+ bool returnsValue = false) {
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
+ Region clonedRegion;
+ IRMapping mapper;
+ srcRegion.cloneInto(&clonedRegion, mapper);
+
+ SmallVector<Type> paramTypes;
+ llvm::copy(srcRegion.getArgumentTypes(), std::back_inserter(paramTypes));
+ Type resultType = returnsValue
+ ? srcRegion.getArgument(0).getType()
+ : LLVM::LLVMVoidType::get(rewriter.getContext());
+ LLVM::LLVMFunctionType funcType =
+ LLVM::LLVMFunctionType::get(resultType, paramTypes);
+
+ LLVM::LLVMFuncOp func =
+ LLVM::LLVMFuncOp::create(rewriter, loc, funcName, funcType);
+ func.setAlwaysInline(true);
+ rewriter.inlineRegionBefore(clonedRegion, func.getRegion(),
+ func.getRegion().end());
+ for (auto &block : func.getRegion().getBlocks()) {
+ if (isa<omp::YieldOp>(block.getTerminator())) {
+ omp::YieldOp yieldOp = cast<omp::YieldOp>(block.getTerminator());
+ rewriter.setInsertionPoint(yieldOp);
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(yieldOp, TypeRange(),
+ yieldOp.getOperands());
+ }
+ }
+ return func;
+ }
+};
+} // namespace
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index a9da6c2..744a595 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -27,6 +27,7 @@
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/DebugLog.h"
@@ -291,9 +292,102 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
}
};
+// Pattern to eliminate ExecuteRegionOp results which forward external
+// values from the region. In case there are multiple yield operations,
+// all of them must have the same operands in order for the pattern to be
+// applicable.
+struct ExecuteRegionForwardingEliminator
+ : public OpRewritePattern<ExecuteRegionOp> {
+ using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExecuteRegionOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getNumResults() == 0)
+ return failure();
+
+ SmallVector<Operation *> yieldOps;
+ for (Block &block : op.getRegion()) {
+ if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator()))
+ yieldOps.push_back(yield.getOperation());
+ }
+
+ if (yieldOps.empty())
+ return failure();
+
+ // Check if all yield operations have the same operands.
+ auto yieldOpsOperands = yieldOps[0]->getOperands();
+ for (auto *yieldOp : yieldOps) {
+ if (yieldOp->getOperands() != yieldOpsOperands)
+ return failure();
+ }
+
+ SmallVector<Value> externalValues;
+ SmallVector<Value> internalValues;
+ SmallVector<Value> opResultsToReplaceWithExternalValues;
+ SmallVector<Value> opResultsToKeep;
+ for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) {
+ if (isValueFromInsideRegion(yieldedValue, op)) {
+ internalValues.push_back(yieldedValue);
+ opResultsToKeep.push_back(op.getResult(index));
+ } else {
+ externalValues.push_back(yieldedValue);
+ opResultsToReplaceWithExternalValues.push_back(op.getResult(index));
+ }
+ }
+ // No yielded external values - nothing to do.
+ if (externalValues.empty())
+ return failure();
+
+ // There are yielded external values - create a new execute_region returning
+ // just the internal values.
+ SmallVector<Type> resultTypes;
+ for (Value value : internalValues)
+ resultTypes.push_back(value.getType());
+ auto newOp =
+ ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes));
+ newOp->setAttrs(op->getAttrs());
+
+ // Move old op's region to the new operation.
+ rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
+ newOp.getRegion().end());
+
+ // Replace all yield operations with a new yield operation with updated
+ // results. scf.execute_region must have at least one yield operation.
+ for (auto *yieldOp : yieldOps) {
+ rewriter.setInsertionPoint(yieldOp);
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp,
+ ValueRange(internalValues));
+ }
+
+ // Replace the old operation with the external values directly.
+ rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues,
+ externalValues);
+ // Replace the old operation's remaining results with the new operation's
+ // results.
+ rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults());
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+private:
+ bool isValueFromInsideRegion(Value value,
+ ExecuteRegionOp executeRegionOp) const {
+ // Check if the value is defined within the execute_region
+ if (Operation *defOp = value.getDefiningOp())
+ return &executeRegionOp.getRegion() == defOp->getParentRegion();
+
+ // If it's a block argument, check if it's from within the region
+ if (BlockArgument blockArg = dyn_cast<BlockArgument>(value))
+ return &executeRegionOp.getRegion() == blockArg.getParentRegion();
+
+ return false; // Value is from outside the region
+ }
+};
+
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
+ results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner,
+ ExecuteRegionForwardingEliminator>(context);
}
void ExecuteRegionOp::getSuccessorRegions(
@@ -2490,8 +2584,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
changed = true;
if (!constantTrue)
- constantTrue = rewriter.create<arith::ConstantOp>(
- op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
+ constantTrue = arith::ConstantOp::create(
+ rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
rewriter.modifyOpInPlace(use.getOwner(),
[&]() { use.set(constantTrue); });
@@ -2500,8 +2594,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
changed = true;
if (!constantFalse)
- constantFalse = rewriter.create<arith::ConstantOp>(
- op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
+ constantFalse = arith::ConstantOp::create(
+ rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
rewriter.modifyOpInPlace(use.getOwner(),
[&]() { use.set(constantFalse); });
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 5dc61a2..335ca1a 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -69,10 +69,10 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
Sharding sourceSharding,
TypedValue<ShapedType> sourceShard, GridOp grid,
int64_t splitTensorAxis, GridAxis splitGridAxis) {
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+ TypedValue<ShapedType> targetShard =
AllSliceOp::create(builder, sourceShard, grid,
ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis)
- .getResult());
+ .getResult();
Sharding targetSharding = targetShardingInSplitLastAxis(
builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis);
return {targetShard, targetSharding};
@@ -204,9 +204,8 @@ static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
APInt(64, splitTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, grid, targetSharding);
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
- tensor::CastOp::create(builder, targetShape, allGatherResult)
- .getResult());
+ TypedValue<ShapedType> targetShard =
+ tensor::CastOp::create(builder, targetShape, allGatherResult).getResult();
return {targetShard, targetSharding};
}
@@ -336,8 +335,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, grid, targetSharding);
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
- tensor::CastOp::create(builder, targetShape, allToAllResult).getResult());
+ TypedValue<ShapedType> targetShard =
+ tensor::CastOp::create(builder, targetShape, allToAllResult).getResult();
return {targetShard, targetSharding};
}
@@ -510,8 +509,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
auto targetSharding = target.getSharding();
ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
- cast<TypedValue<ShapedType>>(source.getSrc()),
- sourceShardValue);
+ source.getSrc(), sourceShardValue);
}
TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index 1cba1bb..32eb286 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -12,6 +12,96 @@
namespace mlir {
namespace tosa {
+llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
+ return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
+}
+
+TosaSpecificationVersion getMinVersion(const Profile &profile) {
+ switch (profile) {
+ case Profile::pro_int:
+ case Profile::pro_fp:
+ return TosaSpecificationVersion(1, 0);
+ case Profile::none:
+ return TosaSpecificationVersion(0, 0);
+ }
+ llvm_unreachable("Unknown TOSA profile");
+}
+
+TosaSpecificationVersion getMinVersion(const Extension &extension) {
+ switch (extension) {
+ case Extension::int16:
+ case Extension::int4:
+ case Extension::bf16:
+ case Extension::fp8e4m3:
+ case Extension::fp8e5m2:
+ case Extension::fft:
+ case Extension::variable:
+ case Extension::controlflow:
+ case Extension::doubleround:
+ case Extension::inexactround:
+ case Extension::dynamic:
+ return TosaSpecificationVersion(1, 0);
+ case Extension::mxfp:
+ return TosaSpecificationVersion(1, 1);
+ case Extension::none:
+ return TosaSpecificationVersion(0, 0);
+ }
+ llvm_unreachable("Unknown TOSA extension");
+}
+
+TosaSpecificationVersion getMinVersion(const Level &level) {
+ switch (level) {
+ case Level::eightK:
+ case Level::none:
+ return TosaSpecificationVersion(1, 0);
+ }
+ llvm_unreachable("Unknown TOSA level");
+}
+
+FailureOr<TargetEnv>
+TargetEnv::createTargetEnvFromAttr(TargetEnvAttr targetAttr,
+ Location targetEnvAttrLoc) {
+ if (failed(verifyTargetInformation(targetAttr, targetEnvAttrLoc)))
+ return failure();
+
+ return TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
+ targetAttr.getProfiles(), targetAttr.getExtensions());
+}
+
+LogicalResult TargetEnv::verifyTargetInformation(TargetEnvAttr targetAttr,
+ Location targetAttrLoc) {
+ TosaSpecificationVersion targetVersion(targetAttr.getSpecificationVersion());
+
+ const auto isCompatibleWithTargetVersion =
+ [&](const auto &targetEnum, Location targetAttrLoc,
+ StringRef enumName) -> LogicalResult {
+ const TosaSpecificationVersion minRequiredVersion =
+ getMinVersion(targetEnum);
+ if (!targetVersion.isBackwardsCompatibleWith(minRequiredVersion))
+ return emitError(targetAttrLoc, enumName)
+ << " '" << stringifyEnum(targetEnum)
+ << "' is not compatible with the target version "
+ << stringifyVersion(targetVersion)
+ << ", minimum required version is "
+ << stringifyVersion(minRequiredVersion);
+ return success();
+ };
+
+ for (const auto &profile : targetAttr.getProfiles())
+ if (failed(
+ isCompatibleWithTargetVersion(profile, targetAttrLoc, "profile")))
+ return failure();
+ for (const auto &extension : targetAttr.getExtensions())
+ if (failed(isCompatibleWithTargetVersion(extension, targetAttrLoc,
+ "extension")))
+ return failure();
+ if (failed(isCompatibleWithTargetVersion(targetAttr.getLevel(), targetAttrLoc,
+ "level")))
+ return failure();
+
+ return success();
+}
+
TargetEnvAttr lookupTargetEnv(Operation *op) {
while (op) {
op = SymbolTable::getNearestSymbolTable(op);
@@ -39,9 +129,5 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
return getDefaultTargetEnv(op->getContext());
}
-llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
- return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
-}
-
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index caf8016..a85ff10a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -76,28 +76,6 @@ template <typename OpTy>
struct PoolPadFoldAdaptor;
template <>
-struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
- using OpTy = tosa::AvgPool2dOp;
- static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
- const llvm::ArrayRef<int64_t> kernel = op.getKernel();
- if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
- newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
- return false;
- return true;
- }
- static bool checkPadConstCompliance(OpTy op, Value padConst) {
- return checkMatchingPadConstAndZp(padConst, op.getInputZp());
- }
- static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
- Value padInput, ArrayRef<int64_t> newPad) {
- rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
- op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
- op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
- op.getAccType());
- }
-};
-
-template <>
struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
using OpTy = tosa::MaxPool2dOp;
static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
@@ -245,13 +223,6 @@ struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
};
} // namespace
-void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
- PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
- context);
-}
-
void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
@@ -1001,8 +972,12 @@ OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
!outputTy.hasStaticShape())
return {};
- if (inputTy.getDimSize(getAxis()) == 1)
- return DenseElementsAttr::get(outputTy, 0);
+ const Type outputElementTy = getElementTypeOrSelf(outputTy);
+ if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
+ const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
+ const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
+ return DenseElementsAttr::get(outputTy, zero);
+ }
return {};
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 00f84bc..6cd0eae 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -321,6 +321,19 @@ ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
}
}
+ // special handling: block_size accepts a *bare* BlockSizeMode enum
+ if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
+ if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) {
+ auto sym = symbolizeBlockSize(kw);
+ if (!sym)
+ return parser.emitError(parser.getCurrentLocation())
+ << "invalid block_size value: " << kw;
+ auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());
+ outAttrs.push_back(NamedAttribute(name, attr));
+ return success();
+ }
+ }
+
// Default path: parse any normal attribute literal, including fully qualified
// enum keyword
Attribute attr;
@@ -373,6 +386,8 @@ void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
} else if (auto nanPropagationModeAttr =
dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
parser << nanPropagationModeAttr.getValue();
+ } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
+ parser << blockSizeAttr.getValue();
} else {
parser.printAttribute(attr);
}
@@ -508,6 +523,15 @@ void ReduceMinOp::print(OpAsmPrinter &parser) {
printWithNanPropagationHandling(parser, *this);
}
+ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
//===----------------------------------------------------------------------===//
// Tosa utilities.
//===----------------------------------------------------------------------===//
@@ -933,32 +957,35 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
// verify that inType and outType have same element types
template <typename T>
-static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
- auto inputType = llvm::dyn_cast<TensorType>(inType);
- auto outputType = llvm::dyn_cast<TensorType>(outType);
- if (!inputType) {
- op.emitOpError("expect shaped tensor for input, got ") << inType;
+static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
+ StringRef aName = "input",
+ StringRef bName = "output") {
+ auto aTType = llvm::dyn_cast<TensorType>(aType);
+ auto bTType = llvm::dyn_cast<TensorType>(bType);
+ if (!aTType) {
+ op.emitOpError("expect shaped tensor for") << aName << ", got " << aType;
return failure();
}
- if (!outputType) {
- op.emitOpError("expect shaped tensor for output, got ") << outType;
+ if (!bTType) {
+ op.emitOpError("expect shaped tensor for") << bName << ", got" << bType;
return failure();
}
- auto inputElementType = inputType.getElementType();
- auto outputElementType = outputType.getElementType();
- auto inputQuantType =
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
- auto outputQuantType =
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
- if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
- (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
- inputElementType != outputElementType) {
+ auto aElementType = aTType.getElementType();
+ auto bElementType = bTType.getElementType();
+ auto aQuantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
+ auto bQuantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
+ if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
+ (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
+ aElementType != bElementType) {
// only check if both element types are int/index/float/UniformQuantized
// eg, not sure how to check quant::QuantizedType
// this happens in test_conv2d_q_grouped_convolution in
// tfl-to-tosa-pipeline.mlir
- op.emitOpError("expect input and output to have same element type, got ")
- << inputElementType << " and " << outputElementType;
+ op.emitOpError("expect ")
+ << aName << " and " << bName << " to have same element type, got "
+ << aElementType << " and " << bElementType;
return failure();
}
return success();
@@ -1846,6 +1873,161 @@ LogicalResult MatMulOp::verify() {
return success();
}
+LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ MatmulTBlockScaledOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic);
+
+ const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
+ if (aDataShape.hasRank()) {
+ outShape[0] = aDataShape.getDimSize(0);
+ outShape[1] = aDataShape.getDimSize(1);
+ }
+
+ const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
+ if (aScaleShape.hasRank()) {
+ outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
+ : outShape[0];
+ outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
+ : outShape[1];
+ }
+
+ // If B batch size is 1, it is broadcast across A's batch size
+ const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
+ if (bDataShape.hasRank()) {
+ const int64_t bDataBatchSize = bDataShape.getDimSize(0);
+ if (bDataBatchSize != 1)
+ outShape[0] =
+ ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
+ outShape[2] = bDataShape.getDimSize(1);
+ }
+
+ const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
+ if (bScaleShape.hasRank()) {
+ const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
+ if (bScaleBatchSize != 1)
+ outShape[0] =
+ ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
+ outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
+ : outShape[2];
+ }
+
+ inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+ return success();
+}
+
+LogicalResult MatmulTBlockScaledOp::verify() {
+ // Verify same input data types
+ const Type aDataType = getAData().getType();
+ const Type bDataType = getBData().getType();
+ if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data",
+ "B_data")))
+ return failure();
+
+ auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
+ const StringRef operandName,
+ const StringRef dimName) -> LogicalResult {
+ if (ShapedType::isDynamic(currDim)) {
+ currDim = newDim;
+ return success();
+ } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
+ return emitOpError("expected ")
+ << dimName << " of " << operandName << " to match size " << currDim
+ << ", got " << newDim;
+ }
+ return success();
+ };
+
+ // Verify input shape compatibility
+ int64_t N = ShapedType::kDynamic;
+ int64_t D = ShapedType::kDynamic;
+ int64_t H = ShapedType::kDynamic;
+ int64_t W = ShapedType::kDynamic;
+ int64_t C = ShapedType::kDynamic;
+ int64_t multiplesOfC = ShapedType::kDynamic;
+
+ const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType);
+ if (aDataShape.hasRank()) {
+ N = aDataShape.getDimSize(0);
+ H = aDataShape.getDimSize(1);
+ C = aDataShape.getDimSize(2);
+ }
+
+ const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
+ if (aScaleShape.hasRank()) {
+ if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale",
+ "batch")) ||
+ failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale",
+ "height")))
+ return failure();
+ multiplesOfC = aScaleShape.getDimSize(2);
+ }
+
+ const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
+ if (bDataShape.hasRank()) {
+ if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data",
+ "batch")) ||
+ failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data",
+ "channels")))
+ return failure();
+ W = bDataShape.getDimSize(1);
+ }
+
+ const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
+ if (bScaleShape.hasRank()) {
+ if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale",
+ "batch")) ||
+ failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale",
+ "width")) ||
+ failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2),
+ "b_scale", "C/block_size")))
+ return failure();
+ }
+
+ // Verify batch size is broadcast compatible
+ if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
+ return emitOpError("expect B matrix batch size to be broadcast compatible "
+ "with A, got D=")
+ << D << " vs N=" << N;
+
+ // Verify C is a multiple of block size
+ const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ if (ShapedType::isStatic(C) && C % blockSize != 0)
+ return emitOpError("expect C to be a multiple of block size, got C=")
+ << C << ", block_size=" << blockSize;
+
+ // Verify multiplesOfC is C / block size
+ if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
+ multiplesOfC != C / blockSize)
+ return emitOpError(
+ "expect scale operands dimension 2 to equal C/block_size (")
+ << C << "/" << blockSize << ")"
+ << ", got " << multiplesOfC;
+
+ // Verify output shape
+ N = ShapedType::isDynamic(N) ? D : N;
+ const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
+ const auto outputType = cast<ShapedType>(getResult().getType());
+ if (outputType.hasRank() &&
+ failed(
+ verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
+ InFlightDiagnostic opError = emitOpError("expected output shape ");
+ auto stringifyDim = [&](int64_t d) {
+ if (ShapedType::isDynamic(d))
+ opError << "?";
+ else
+ opError << d;
+ };
+ llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
+ opError << " to be compatible with expected output shape ";
+ llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
+ return opError;
+ }
+
+ return success();
+}
+
LogicalResult tosa::PadOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
PadOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index f072e3e..e965ae0 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -25,6 +25,12 @@ TosaProfileCompliance::TosaProfileCompliance() {
const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
+ // micro-scaling formats
+ const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6};
+ const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
+ const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
+ const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
+
// The profile-based compliance content below is auto-generated by a script
// in https://git.mlplatform.org/tosa/specification.git
#include "mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc"
@@ -269,6 +275,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// For the most of tosa operators, all operands are profile/extension related
// and hence are all considered in this profile-based compilance check.
+ POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled)
POPULATE_PROFILE_INFO_COMMON(Cast)
POPULATE_PROFILE_INFO_COMMON(Const)
POPULATE_PROFILE_INFO_COMMON(ArgMax)
@@ -623,6 +630,14 @@ TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) {
return {"fp8e4m3"};
} else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) {
return {"fp8e5m2"};
+ } else if (typeInfo.typeID == mlir::Float6E2M3FNType::getTypeID()) {
+ return {"fp6e2m3"};
+ } else if (typeInfo.typeID == mlir::Float6E3M2FNType::getTypeID()) {
+ return {"fp6e3m2"};
+ } else if (typeInfo.typeID == mlir::Float4E2M1FNType::getTypeID()) {
+ return {"fp4e2m1"};
+ } else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) {
+ return {"fp8e8m0"};
}
llvm_unreachable("unknown type");
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 82f2f7e..3f874d9 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -657,6 +657,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_SIZES(TransposeConv2D);
CHECK_SIZES(FFT2d);
CHECK_SIZES(MatMul);
+ CHECK_SIZES(MatmulTBlockScaled);
CHECK_SIZES(MaxPool2d);
CHECK_SIZES(RFFT2d);
// Scatter/Gather Operators
@@ -1192,9 +1193,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
if (isa<FloatType>(type)) {
return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
- Float8E5M2Type>(type);
- }
- if (auto intTy = dyn_cast<IntegerType>(type)) {
+ Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
+ Float6E3M2FNType, Float8E8M0FNUType>(type);
+ } else if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isSignless()) {
switch (intTy.getWidth()) {
case 1:
@@ -1220,13 +1221,19 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
}
void TosaValidation::runOnOperation() {
+ ModuleOp modOp = getOperation();
+ const TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(modOp);
+ const auto maybeTargetEnv =
+ tosa::TargetEnv::createTargetEnvFromAttr(targetEnvAttr, modOp.getLoc());
+ if (failed(maybeTargetEnv))
+ return signalPassFailure();
+ targetEnv = *maybeTargetEnv;
+
TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
if (!tosaDialect)
return;
- targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation()));
-
- getOperation().walk([&](Operation *op) {
+ modOp.walk([&](Operation *op) {
if (op->getDialect() != tosaDialect)
return;
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
index 8f46ad6..ef49c86 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
@@ -74,9 +74,9 @@ struct MixedSizeInputShuffleOpRewrite final
for (int64_t i = 0; i < origNumElems; ++i)
promoteMask[i] = i;
- Value promotedInput = rewriter.create<vector::ShuffleOp>(
- shuffleOp.getLoc(), promotedType, inputToPromote, inputToPromote,
- promoteMask);
+ Value promotedInput =
+ vector::ShuffleOp::create(rewriter, shuffleOp.getLoc(), promotedType,
+ inputToPromote, inputToPromote, promoteMask);
// Create the final shuffle with the promoted inputs.
Value promotedV1 = promoteV1 ? promotedInput : shuffleOp.getV1();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 7c019e7..8b5e950 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -341,13 +341,18 @@ private:
/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have a dimension equal to the
/// original type rank and should be a projection where the results are the
-/// distributed dimensions. The number of results should be equal to the number
+/// distributed dimensions. If the number of results is zero there is no
+/// distribution (i.e. original type is returned).
+/// Otherwise, The number of results should be equal to the number
/// of warp sizes which is currently limited to 1.
/// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1)
/// and a warp size of 16 would distribute the second dimension (associated to
/// d1) and return vector<16x2x64>
static VectorType getDistributedType(VectorType originalType, AffineMap map,
int64_t warpSize) {
+ // If the map has zero results, return the original type.
+ if (map.getNumResults() == 0)
+ return originalType;
SmallVector<int64_t> targetShape(originalType.getShape());
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
unsigned position = map.getDimPosition(i);
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 1599ae9..24e9095 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -736,7 +736,7 @@ OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
OpBuilder &builder) {
auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
- return builder.create<ArithOp>(loc, aVal, bVal).getResult();
+ return ArithOp::create(builder, loc, aVal, bVal).getResult();
}
// a helper utility to perform division operation on OpFoldResult and int64_t.
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 26770b3..d09dc19 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1505,14 +1505,19 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
return AffineMap::get(val.getContext());
// Get the layout of the vector type.
xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
- // If no layout is specified, assume the inner most dimension is distributed
- // for now.
+ // If no layout is specified, that means no distribution.
if (!layout)
- return AffineMap::getMultiDimMapWithTargets(
- vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
+ return AffineMap::getMultiDimMapWithTargets(vecRank, {},
+ val.getContext());
+ // Expecting vector and layout rank to match.
+ assert(layout.getRank() == vecRank &&
+ "Expecting vector and layout rank to match");
+ // A dimension is distributed only if layout suggests there are
+ // multiple lanes assigned for this dimension and the shape can be evenly
+ // distributed to those lanes.
SmallVector<unsigned int> distributedDims;
for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
- if (v > 1)
+ if (v > 1 && vecType.getShape()[i] % v == 0)
distributedDims.push_back(i);
}
return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
@@ -1525,15 +1530,13 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
vector::CombiningKind kind, uint32_t size) {
// First reduce on a single thread to get per lane reduction value.
- Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
+ Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
// Parallel reduction using butterfly shuffles.
for (uint64_t i = 1; i < size; i <<= 1) {
- Value shuffled =
- builder
- .create<gpu::ShuffleOp>(loc, laneVal, i,
- /*width=*/size,
- /*mode=*/gpu::ShuffleMode::XOR)
- .getShuffleResult();
+ Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i,
+ /*width=*/size,
+ /*mode=*/gpu::ShuffleMode::XOR)
+ .getShuffleResult();
laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
}
return laneVal;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 31a967d..9fc5ad9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -825,7 +825,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
baseTileValues);
- auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
+ auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
// Get subgroup id
Value sgId =
@@ -837,25 +837,26 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
SmallVector<Value, 2> strideConsts;
strideConsts.push_back(
- rewriter.create<arith::ConstantIndexOp>(loc, colStride));
+ arith::ConstantIndexOp::create(rewriter, loc, colStride));
if (rows > 1)
strideConsts.insert(
strideConsts.begin(),
- rewriter.create<arith::ConstantIndexOp>(loc, rowStride));
+ arith::ConstantIndexOp::create(rewriter, loc, rowStride));
SmallVector<Value> newConstOps;
for (auto offsets : *sgOffsets) {
// Multiply offset with stride, broadcast it and add to baseConstVec
- Value mulOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
for (size_t i = 0; i < strideConsts.size(); ++i) {
- Value mul = rewriter.create<arith::MulIOp>(
- loc, rewriter.getIndexType(), offsets[i], strideConsts[i]);
- mulOffset = rewriter.create<arith::AddIOp>(
- loc, rewriter.getIndexType(), mulOffset, mul);
+ Value mul =
+ arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
+ offsets[i], strideConsts[i]);
+ mulOffset = arith::AddIOp::create(
+ rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
}
// Broadcast to baseConstVec size
- auto bcastOffset = rewriter.create<vector::BroadcastOp>(
- loc, baseConstVec.getType(), mulOffset);
+ auto bcastOffset = vector::BroadcastOp::create(
+ rewriter, loc, baseConstVec.getType(), mulOffset);
auto finalConst =
arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
setLayoutIfNeeded(baseConstVec);
@@ -1138,8 +1139,8 @@ struct WgToSgVectorShapeCastOp
SmallVector<Value> newShapeCastOps;
for (auto src : adaptor.getSource()) {
- auto newShapeCast =
- rewriter.create<vector::ShapeCastOp>(op.getLoc(), newResultType, src);
+ auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+ newResultType, src);
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
!layout.getEffectiveInstDataAsInt().empty())
xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
@@ -1201,9 +1202,9 @@ struct WgToSgMultiDimReductionOp
SmallVector<Value> newReductions;
for (auto sgSrc : adaptor.getSource()) {
- auto newOp = rewriter.create<vector::MultiDimReductionOp>(
- op.getLoc(), newDstType, op.getKind(), sgSrc, adaptor.getAcc()[0],
- op.getReductionDims());
+ auto newOp = vector::MultiDimReductionOp::create(
+ rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
+ adaptor.getAcc()[0], op.getReductionDims());
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
!layout.getEffectiveInstDataAsInt().empty())
xegpu::setDistributeLayoutAttr(newOp->getResult(0),
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt b/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt
index 6ef1529..c712c64b 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt
@@ -21,6 +21,6 @@ set_property(TARGET MLIRSparseTensorRuntime PROPERTY CXX_STANDARD 17)
check_cxx_compiler_flag(-Wweak-vtables
COMPILER_SUPPORTS_WARNING_WEAK_VTABLES)
if(COMPILER_SUPPORTS_WARNING_WEAK_VTABLES)
- target_compile_options(MLIRSparseTensorRuntime PUBLIC
+ target_compile_options(MLIRSparseTensorRuntime PRIVATE
"-Wweak-vtables")
endif()
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index 4d81918..776b5c6 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -378,10 +378,8 @@ struct SourceMgrDiagnosticHandlerImpl {
}
// Otherwise, try to load the source file.
- auto bufferOrErr = llvm::MemoryBuffer::getFile(filename);
- if (!bufferOrErr)
- return 0;
- unsigned id = mgr.AddNewSourceBuffer(std::move(*bufferOrErr), SMLoc());
+ std::string ignored;
+ unsigned id = mgr.AddIncludeFile(std::string(filename), SMLoc(), ignored);
filenameToBufId[filename] = id;
return id;
}
diff --git a/mlir/lib/RegisterAllPasses.cpp b/mlir/lib/RegisterAllPasses.cpp
index dd413d2de..d7e321a 100644
--- a/mlir/lib/RegisterAllPasses.cpp
+++ b/mlir/lib/RegisterAllPasses.cpp
@@ -33,6 +33,7 @@
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+#include "mlir/Dialect/OpenMP/Transforms/Passes.h"
#include "mlir/Dialect/Quant/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -80,6 +81,7 @@ void mlir::registerAllPasses() {
memref::registerMemRefPasses();
shard::registerShardPasses();
ml_program::registerMLProgramPasses();
+ omp::registerOpenMPPasses();
quant::registerQuantPasses();
registerSCFPasses();
registerShapePasses();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 8de49dd..f284540 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -357,14 +357,8 @@ static LogicalResult checkImplementationStatus(Operation &op) {
result = todo("priority");
};
auto checkPrivate = [&todo](auto op, LogicalResult &result) {
- if constexpr (std::is_same_v<std::decay_t<decltype(op)>, omp::TargetOp>) {
- // Privatization is supported only for included target tasks.
- if (!op.getPrivateVars().empty() && op.getNowait())
- result = todo("privatization for deferred target tasks");
- } else {
- if (!op.getPrivateVars().empty() || op.getPrivateSyms())
- result = todo("privatization");
- }
+ if (!op.getPrivateVars().empty() || op.getPrivateSyms())
+ result = todo("privatization");
};
auto checkReduction = [&todo](auto op, LogicalResult &result) {
if (isa<omp::TeamsOp>(op))
@@ -451,7 +445,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkDevice(op, result);
checkInReduction(op, result);
checkIsDevicePtr(op, result);
- checkPrivate(op, result);
})
.Default([](Operation &) {
// Assume all clauses for an operation can be translated unless they are
@@ -3833,6 +3826,58 @@ static llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
}
+// Convert the MLIR map flag set to the runtime map flag set for embedding
+// in LLVM-IR. This is important as the two bit-flag lists do not correspond
+// 1-to-1 as there's flags the runtime doesn't care about and vice versa.
+// Certain flags are discarded here such as RefPtee and co.
+static llvm::omp::OpenMPOffloadMappingFlags
+convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
+ auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
+ return (mlirFlags & flag) == flag;
+ };
+
+ llvm::omp::OpenMPOffloadMappingFlags mapType =
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::to))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::from))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::always))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::del))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::return_param))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::priv))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::literal))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::implicit))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::close))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::present))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::ompx_hold))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
+
+ if (mapTypeToBool(omp::ClauseMapFlags::attach))
+ mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
+
+ return mapType;
+}
+
static void collectMapDataFromMapOperands(
MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
@@ -3880,8 +3925,7 @@ static void collectMapDataFromMapOperands(
getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
mapData.BaseType.back(), builder, moduleTranslation));
mapData.MapClause.push_back(mapOp.getOperation());
- mapData.Types.push_back(
- llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType()));
+ mapData.Types.push_back(convertClauseMapFlags(mapOp.getMapType()));
mapData.Names.push_back(LLVM::createMappingInformation(
mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None);
@@ -3950,8 +3994,7 @@ static void collectMapDataFromMapOperands(
Value offloadPtr =
mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
- auto mapType =
- static_cast<llvm::omp::OpenMPOffloadMappingFlags>(mapOp.getMapType());
+ auto mapType = convertClauseMapFlags(mapOp.getMapType());
auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
mapData.OriginalValue.push_back(origValue);
@@ -4299,8 +4342,7 @@ static void processMapMembersWithParent(
// in part as we currently have substantially less information on the data
// being mapped at this stage.
if (checkIfPointerMap(memberClause)) {
- auto mapFlag =
- llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
+ auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
@@ -4319,8 +4361,7 @@ static void processMapMembersWithParent(
// Same MemberOfFlag to indicate its link with parent and other members
// of.
- auto mapFlag =
- llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
+ auto mapFlag = convertClauseMapFlags(memberClause.getMapType());
mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index d9ad8fb..6492708 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -702,8 +702,8 @@ spirv::Deserializer::processGraphEntryPointARM(ArrayRef<uint32_t> operands) {
// RAII guard to reset the insertion point to previous value when done.
OpBuilder::InsertionGuard insertionGuard(opBuilder);
opBuilder.setInsertionPoint(graphARM);
- opBuilder.create<spirv::GraphEntryPointARMOp>(
- unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
+ spirv::GraphEntryPointARMOp::create(
+ opBuilder, unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
opBuilder.getArrayAttr(interface));
return success();
@@ -736,7 +736,7 @@ spirv::Deserializer::processGraphARM(ArrayRef<uint32_t> operands) {
std::string graphName = getGraphSymbol(graphID);
auto graphOp =
- opBuilder.create<spirv::GraphARMOp>(unknownLoc, graphName, graphType);
+ spirv::GraphARMOp::create(opBuilder, unknownLoc, graphName, graphType);
curGraph = graphMap[graphID] = graphOp;
Block *entryBlock = graphOp.addEntryBlock();
LLVM_DEBUG({
@@ -844,7 +844,7 @@ spirv::Deserializer::processOpGraphSetOutputARM(ArrayRef<uint32_t> operands) {
LogicalResult
spirv::Deserializer::processGraphEndARM(ArrayRef<uint32_t> operands) {
// Create GraphOutputsARM instruction.
- opBuilder.create<spirv::GraphOutputsARMOp>(unknownLoc, graphOutputs);
+ spirv::GraphOutputsARMOp::create(opBuilder, unknownLoc, graphOutputs);
// Process OpGraphEndARM.
if (!operands.empty()) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index b56e778..b88fbaa 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -260,9 +260,9 @@ static std::string getDecorationName(StringRef attrName) {
}
template <typename AttrTy, typename EmitF>
-LogicalResult processDecorationList(Location loc, Decoration decoration,
- Attribute attrList, StringRef attrName,
- EmitF emitter) {
+static LogicalResult processDecorationList(Location loc, Decoration decoration,
+ Attribute attrList,
+ StringRef attrName, EmitF emitter) {
auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
if (!arrayAttr) {
return emitError(loc, "expecting array attribute of ")
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
index 366ba8f..048e964 100644
--- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -406,7 +406,7 @@ private:
auto returnOperands = popOperands(resTypes);
if (failed(returnOperands))
return failure();
- builder.create<BlockReturnOp>(opLoc, *returnOperands);
+ BlockReturnOp::create(builder, opLoc, *returnOperands);
LDBG() << "end of parsing of a block";
return bodyParsingRes->endingByte;
}
@@ -1000,7 +1000,7 @@ parsed_inst_t ExpressionParser::parseBlockLikeOp(OpBuilder &builder) {
builder.createBlock(curRegion, curRegion->end(), resTypes, locations);
builder.setInsertionPointToEnd(curBlock);
auto blockOp =
- builder.create<OpToCreate>(*currentOpLoc, *inputOps, successor);
+ OpToCreate::create(builder, *currentOpLoc, *inputOps, successor);
auto *blockBody = blockOp.createBlock();
if (failed(parseBlockContent(builder, blockBody, resTypes, *opLoc, blockOp)))
return failure();
@@ -1047,8 +1047,8 @@ inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
auto *successor =
builder.createBlock(curRegion, curRegion->end(), resTypes, locations);
builder.setInsertionPointToEnd(curBlock);
- auto ifOp = builder.create<IfOp>(*currentOpLoc, conditionValue->front(),
- *inputOps, successor);
+ auto ifOp = IfOp::create(builder, *currentOpLoc, conditionValue->front(),
+ *inputOps, successor);
auto *ifEntryBlock = ifOp.createIfBlock();
constexpr auto ifElseFilter =
ByteSequence<WasmBinaryEncoding::endByte,
@@ -1091,9 +1091,9 @@ inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
auto branchArgs = popOperands(inputTypes);
if (failed(branchArgs))
return failure();
- builder.create<BranchIfOp>(*currentOpLoc, condition->front(),
- builder.getUI32IntegerAttr(*level), *branchArgs,
- elseBlock);
+ BranchIfOp::create(builder, *currentOpLoc, condition->front(),
+ builder.getUI32IntegerAttr(*level), *branchArgs,
+ elseBlock);
builder.setInsertionPointToStart(elseBlock);
return {*branchArgs};
}
@@ -1115,7 +1115,7 @@ ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::call>(
if (failed(inOperands))
return failure();
auto callOp =
- builder.create<FuncCallOp>(loc, resTypes, callee.symbol, *inOperands);
+ FuncCallOp::create(builder, loc, resTypes, callee.symbol, *inOperands);
return {callOp.getResults()};
}
@@ -1391,8 +1391,8 @@ inline parsed_inst_t ExpressionParser::buildConvertOp(OpBuilder &builder,
auto operand = popOperands(intype);
if (failed(operand))
return failure();
- auto op = builder.create<opType>(*currentOpLoc, outType, operand->front(),
- extraArgs...);
+ auto op = opType::create(builder, *currentOpLoc, outType, operand->front(),
+ extraArgs...);
LDBG() << "Built operation: " << op;
return {{op.getResult()}};
}