Diffstat (limited to 'mlir')
-rw-r--r--  mlir/examples/standalone/python/CMakeLists.txt | 11
-rw-r--r--  mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 49
-rw-r--r--  mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h | 135
-rw-r--r--  mlir/include/mlir/Dialect/OpenACC/OpenACC.h | 8
-rw-r--r--  mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td | 15
-rw-r--r--  mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h | 5
-rw-r--r--  mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt | 2
-rw-r--r--  mlir/include/mlir/Dialect/OpenMP/Transforms/CMakeLists.txt | 5
-rw-r--r--  mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.h | 26
-rw-r--r--  mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.td | 26
-rw-r--r--  mlir/include/mlir/IR/Builders.h | 5
-rw-r--r--  mlir/include/mlir/IR/Value.h | 10
-rw-r--r--  mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 10
-rw-r--r--  mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp | 10
-rw-r--r--  mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp | 6
-rw-r--r--  mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 21
-rw-r--r--  mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 4
-rw-r--r--  mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp | 26
-rw-r--r--  mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 24
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 137
-rw-r--r--  mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 14
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp | 29
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 11
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp | 87
-rw-r--r--  mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt | 13
-rw-r--r--  mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp | 26
-rw-r--r--  mlir/lib/Dialect/OpenACC/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 14
-rw-r--r--  mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp | 28
-rw-r--r--  mlir/lib/Dialect/OpenMP/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt | 14
-rw-r--r--  mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp | 445
-rw-r--r--  mlir/lib/Dialect/SCF/IR/SCF.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Shard/Transforms/Partition.cpp | 16
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 7
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 2
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 27
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 31
-rw-r--r--  mlir/lib/RegisterAllPasses.cpp | 2
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 11
-rw-r--r--  mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | 8
-rw-r--r--  mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 20
-rw-r--r--  mlir/python/CMakeLists.txt | 13
-rw-r--r--  mlir/test/Dialect/MemRef/canonicalize.mlir | 126
-rw-r--r--  mlir/test/Dialect/MemRef/expand-strided-metadata.mlir | 127
-rw-r--r--  mlir/test/Dialect/OpenACC/ops.mlir | 43
-rw-r--r--  mlir/test/Dialect/OpenACC/support-analysis-varname.mlir | 88
-rw-r--r--  mlir/test/Dialect/OpenMP/omp-offload-privatization-prepare-by-value.mlir | 157
-rw-r--r--  mlir/test/Dialect/OpenMP/omp-offload-privatization-prepare.mlir | 201
-rw-r--r--  mlir/test/Dialect/Tosa/canonicalize.mlir | 9
-rw-r--r--  mlir/test/Dialect/XeGPU/subgroup-distribute.mlir | 51
-rw-r--r--  mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir | 23
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir | 11
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir | 24
-rw-r--r--  mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir | 34
-rw-r--r--  mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 32
-rw-r--r--  mlir/test/Target/LLVMIR/openmp-todo.mlir | 18
-rw-r--r--  mlir/test/lib/Dialect/OpenACC/CMakeLists.txt | 2
-rw-r--r--  mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp | 2
-rw-r--r--  mlir/test/lib/Dialect/OpenACC/TestOpenACCSupport.cpp | 73
-rw-r--r--  mlir/test/lib/Dialect/Test/TestPatterns.cpp | 2
-rw-r--r--  mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp | 75
65 files changed, 2104 insertions(+), 373 deletions(-)
diff --git a/mlir/examples/standalone/python/CMakeLists.txt b/mlir/examples/standalone/python/CMakeLists.txt
index 905c9449..df19fa8 100644
--- a/mlir/examples/standalone/python/CMakeLists.txt
+++ b/mlir/examples/standalone/python/CMakeLists.txt
@@ -74,7 +74,12 @@ add_mlir_python_common_capi_library(StandalonePythonCAPI
set(StandalonePythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}")
-if(NOT CMAKE_CROSSCOMPILING)
+set(_mlir_python_stubgen_enabled ON)
+if(CMAKE_CROSSCOMPILING OR (NOT LLVM_USE_SANITIZER STREQUAL ""))
+ set(_mlir_python_stubgen_enabled OFF)
+endif()
+
+if(_mlir_python_stubgen_enabled)
# Everything here is very tightly coupled. See the ample descriptions at the bottom of
# mlir/python/CMakeLists.txt.
@@ -141,7 +146,7 @@ set(_declared_sources
)
# For an external projects build, the MLIRPythonExtension.Core.type_stub_gen
# target already exists and can just be added to DECLARED_SOURCES.
-if(EXTERNAL_PROJECT_BUILD AND (NOT CMAKE_CROSSCOMPILING))
+if(EXTERNAL_PROJECT_BUILD AND _mlir_python_stubgen_enabled)
list(APPEND _declared_sources MLIRPythonExtension.Core.type_stub_gen)
endif()
@@ -153,7 +158,7 @@ add_mlir_python_modules(StandalonePythonModules
StandalonePythonCAPI
)
-if(NOT CMAKE_CROSSCOMPILING)
+if(_mlir_python_stubgen_enabled)
if(NOT EXTERNAL_PROJECT_BUILD)
add_dependencies(StandalonePythonModules "${_mlir_typestub_gen_target}")
endif()
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index d959464..4f48385 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1872,6 +1872,55 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
}];
}
+class NVVM_ConvertToFP16x2Op_Base <string srcType, Type srcArgType, string dstType>
+: NVVM_Op<"convert." # !tolower(srcType) # "x2.to." # !tolower(dstType) # "x2"> {
+ let summary = "Convert a pair of " # !tolower(srcType) # " inputs to " # !tolower(dstType) # "x2";
+ let description = [{
+ This Op converts the given }] # !tolower(srcType) # [{ inputs in a }] #
+ !if(!eq(srcType, "F4"), "packed i8", "i8x2 vector") # [{ to }] #
+ !tolower(dstType) # [{.
+
+ The result `dst` is represented as a vector of }] # !tolower(dstType) # [{ elements.
+ }] #
+ !if(!eq(dstType, "F16"),
+ [{The `relu` attribute, when set, lowers to the '.relu' variant of
+    the cvt instruction.}], "") # [{
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+ let results = (outs VectorOfLengthAndType<[2], [!cast<Type>(dstType)]>:$dst);
+ let arguments = !if(!eq(dstType, "F16"),
+ (ins srcArgType:$src,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$srcType),
+ (ins srcArgType:$src,
+ TypeAttr:$srcType));
+ let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)";
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ static IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder);
+ }];
+
+ string llvmBuilder = [{
+ auto [intId, args] =
+ NVVM::Convert}] # srcType # [{x2To}] # dstType #
+ [{x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+ $dst = createIntrinsicCall(builder, intId, args);
+ }];
+}
+
+def NVVM_ConvertF8x2ToF16x2Op :
+ NVVM_ConvertToFP16x2Op_Base<"F8", VectorOfLengthAndType<[2], [I8]>, "F16">;
+def NVVM_ConvertF8x2ToBF16x2Op :
+ NVVM_ConvertToFP16x2Op_Base<"F8", VectorOfLengthAndType<[2], [I8]>, "BF16">;
+def NVVM_ConvertF6x2ToF16x2Op :
+ NVVM_ConvertToFP16x2Op_Base<"F6", VectorOfLengthAndType<[2], [I8]>, "F16">;
+def NVVM_ConvertF4x2ToF16x2Op :
+ NVVM_ConvertToFP16x2Op_Base<"F4", I8, "F16">;
+
//===----------------------------------------------------------------------===//
// NVVM MMA Ops
//===----------------------------------------------------------------------===//
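
The four `def`s above instantiate the shared base class. As a hedged sketch (not part of the patch), building one of these ops through the TableGen-generated collective builder could look like the following; the helper name and the choice of the generic TypeRange/ValueRange overload are assumptions. The printed form implied by the assembly format above would be `%d = nvvm.convert.f8x2.to.f16x2 %s : vector<2xi8> (f8E4M3FN) -> vector<2xf16>`.

    // Hypothetical helper: convert two packed e4m3 values held in a
    // vector<2xi8> into a vector<2xf16> via nvvm.convert.f8x2.to.f16x2.
    static mlir::Value buildF8x2ToF16x2(mlir::OpBuilder &b, mlir::Location loc,
                                        mlir::Value packedSrc) {
      auto dstTy = mlir::VectorType::get({2}, b.getF16Type());
      mlir::NamedAttribute attrs[] = {
          b.getNamedAttr("relu", b.getBoolAttr(false)),
          b.getNamedAttr("srcType",
                         mlir::TypeAttr::get(
                             mlir::Float8E4M3FNType::get(b.getContext())))};
      return mlir::NVVM::ConvertF8x2ToF16x2Op::create(
                 b, loc, mlir::TypeRange{dstTy}, mlir::ValueRange{packedSrc},
                 attrs)
          .getResult();
    }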
diff --git a/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h b/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h
new file mode 100644
index 0000000..0833462
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h
@@ -0,0 +1,135 @@
+//===- OpenACCSupport.h - OpenACC Support Interface -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the OpenACCSupport analysis interface, which provides
+// extensible support for OpenACC passes. Custom implementations
+// can be registered to provide pipeline and dialect-specific information
+// that cannot be adequately expressed through type or operation interfaces
+// alone.
+//
+// Usage Pattern:
+// ==============
+//
+// A pass that needs this functionality should call
+// getAnalysis<OpenACCSupport>(), which will provide either:
+// - A cached version if previously initialized, OR
+// - A default implementation if not previously initialized
+//
+// This analysis is never invalidated (isInvalidated returns false), so it only
+// needs to be initialized once and will persist throughout the pass pipeline.
+//
+// Registering a Custom Implementation:
+// =====================================
+//
+// If a custom implementation is needed, create a pass that runs BEFORE the pass
+// that needs the analysis. In this setup pass, use
+// getAnalysis<OpenACCSupport>() followed by setImplementation() to register
+// your custom implementation. The custom implementation will need to provide
+// implementation for all methods defined in the `OpenACCSupportTraits::Concept`
+// class.
+//
+// Example:
+// void MySetupPass::runOnOperation() {
+// OpenACCSupport &support = getAnalysis<OpenACCSupport>();
+// support.setImplementation(MyCustomImpl());
+// }
+//
+// void MyAnalysisConsumerPass::runOnOperation() {
+// OpenACCSupport &support = getAnalysis<OpenACCSupport>();
+// std::string name = support.getVariableName(someValue);
+// // ... use the analysis results
+// }
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENACC_ANALYSIS_OPENACCSUPPORT_H
+#define MLIR_DIALECT_OPENACC_ANALYSIS_OPENACCSUPPORT_H
+
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/AnalysisManager.h"
+#include <memory>
+#include <string>
+
+namespace mlir {
+namespace acc {
+
+namespace detail {
+/// This class contains internal trait classes used by OpenACCSupport.
+/// It follows the Concept-Model pattern used throughout MLIR (e.g., in
+/// AliasAnalysis and interface definitions).
+struct OpenACCSupportTraits {
+ class Concept {
+ public:
+ virtual ~Concept() = default;
+
+ /// Get the variable name for a given MLIR value.
+ virtual std::string getVariableName(Value v) = 0;
+ };
+
+ /// This class wraps a concrete OpenACCSupport implementation and forwards
+ /// interface calls to it. This provides type erasure, allowing different
+ /// implementation types to be used interchangeably without inheritance.
+ template <typename ImplT>
+ class Model final : public Concept {
+ public:
+ explicit Model(ImplT &&impl) : impl(std::forward<ImplT>(impl)) {}
+ ~Model() override = default;
+
+ std::string getVariableName(Value v) final {
+ return impl.getVariableName(v);
+ }
+
+ private:
+ ImplT impl;
+ };
+};
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+// OpenACCSupport
+//===----------------------------------------------------------------------===//
+
+class OpenACCSupport {
+ using Concept = detail::OpenACCSupportTraits::Concept;
+ template <typename ImplT>
+ using Model = detail::OpenACCSupportTraits::Model<ImplT>;
+
+public:
+ OpenACCSupport() = default;
+ OpenACCSupport(Operation *op) {}
+
+ /// Register a custom OpenACCSupport implementation. Only one implementation
+ /// can be registered at a time; calling this replaces any existing
+ /// implementation.
+ template <typename AnalysisT>
+ void setImplementation(AnalysisT &&analysis) {
+ impl =
+ std::make_unique<Model<AnalysisT>>(std::forward<AnalysisT>(analysis));
+ }
+
+ /// Get the variable name for a given value.
+ ///
+ /// \param v The MLIR value to get the variable name for.
+ /// \return The variable name, or an empty string if unavailable.
+ std::string getVariableName(Value v);
+
+ /// Signal that this analysis should always be preserved so that
+ /// underlying implementation registration is not lost.
+ bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) {
+ return false;
+ }
+
+private:
+ /// The registered custom implementation (if any).
+ std::unique_ptr<Concept> impl;
+};
+
+} // namespace acc
+} // namespace mlir
+
+#endif // MLIR_DIALECT_OPENACC_ANALYSIS_OPENACCSUPPORT_H
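
A purely illustrative sketch of a custom implementation that structurally satisfies `OpenACCSupportTraits::Concept`; `lookupInMyTable` is a hypothetical, pipeline-specific helper, not part of the patch.

    struct MyCustomImpl {
      std::string getVariableName(mlir::Value v) {
        // Prefer pipeline-specific metadata (hypothetical helper)...
        if (std::optional<std::string> name = lookupInMyTable(v))
          return *name;
        // ...and otherwise fall back to the default attribute walk.
        return mlir::acc::getVariableName(v);
      }
    };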
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index e2a60f5..05d2316 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -46,10 +46,10 @@
mlir::acc::CopyinOp, mlir::acc::CreateOp, mlir::acc::PresentOp, \
mlir::acc::NoCreateOp, mlir::acc::AttachOp, mlir::acc::DevicePtrOp, \
mlir::acc::GetDevicePtrOp, mlir::acc::PrivateOp, \
- mlir::acc::FirstprivateOp, mlir::acc::UpdateDeviceOp, \
- mlir::acc::UseDeviceOp, mlir::acc::ReductionOp, \
- mlir::acc::DeclareDeviceResidentOp, mlir::acc::DeclareLinkOp, \
- mlir::acc::CacheOp
+ mlir::acc::FirstprivateOp, mlir::acc::FirstprivateMapInitialOp, \
+ mlir::acc::UpdateDeviceOp, mlir::acc::UseDeviceOp, \
+ mlir::acc::ReductionOp, mlir::acc::DeclareDeviceResidentOp, \
+ mlir::acc::DeclareLinkOp, mlir::acc::CacheOp
#define ACC_DATA_EXIT_OPS \
mlir::acc::CopyoutOp, mlir::acc::DeleteOp, mlir::acc::DetachOp, \
mlir::acc::UpdateHostOp
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index e78cdbe..2f87975 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -787,6 +787,21 @@ def OpenACC_FirstprivateOp : OpenACC_DataEntryOp<"firstprivate",
let extraClassDeclaration = extraClassDeclarationBase;
}
+// The mapping of firstprivate cannot be represented through an `acc.copyin`
+// since that operation includes present counter updates (and private variables
+// do not impact counters). Instead, the below operation is used to represent
+// the mapping of that initial value, which can be used to initialize the
+// private copies.
+def OpenACC_FirstprivateMapInitialOp : OpenACC_DataEntryOp<"firstprivate_map",
+ "mlir::acc::DataClause::acc_firstprivate", "", [],
+ (ins Arg<OpenACC_AnyPointerOrMappableType,"Host variable",[MemRead]>:$var)> {
+  let summary = "Represents the mapping of the initial value used to "
+                "decompose firstprivate semantics.";
+ let results = (outs Arg<OpenACC_AnyPointerOrMappableType,
+ "Accelerator mapped variable",[MemWrite]>:$accVar);
+ let extraClassDeclaration = extraClassDeclarationBase;
+}
+
//===----------------------------------------------------------------------===//
// 2.5.15 reduction clause
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h
index 378f434..0ee88c6 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h
@@ -38,6 +38,11 @@ std::optional<ClauseDefaultValue> getDefaultAttr(mlir::Operation *op);
/// Get the type category of an OpenACC variable.
mlir::acc::VariableTypeCategory getTypeCategory(mlir::Value var);
+/// Attempts to extract the variable name from a value by walking through
+/// view-like operations until an `acc.var_name` attribute is found. Returns
+/// an empty string if no name is found.
+std::string getVariableName(mlir::Value v);
+
} // namespace acc
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
index b6c8dba..691163d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/OpenMP/CMakeLists.txt
@@ -1,3 +1,5 @@
+add_subdirectory(Transforms)
+
set(LLVM_TARGET_DEFINITIONS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Frontend/OpenMP/OMP.td)
mlir_tablegen(OmpCommon.td --gen-directive-decl --directives-dialect=OpenMP)
add_mlir_dialect_tablegen_target(omp_common_td)
diff --git a/mlir/include/mlir/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/OpenMP/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..22f0d92
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name OpenMP)
+add_public_tablegen_target(MLIROpenMPPassIncGen)
+
+add_mlir_doc(Passes OpenMPPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.h b/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.h
new file mode 100644
index 0000000..21b6d1f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.h
@@ -0,0 +1,26 @@
+//===- Passes.h - OpenMP Pass Construction and Registration -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_OPENMP_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace omp {
+
+/// Generate the code for declaring and registering OpenMP passes.
+#define GEN_PASS_DECL
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
+
+} // namespace omp
+} // namespace mlir
+
+#endif // MLIR_DIALECT_OPENMP_TRANSFORMS_PASSES_H
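
A hedged aside on the generated API: given `-name OpenMP` in the tablegen invocation above, the `GEN_PASS_REGISTRATION` block conventionally produces a `registerOpenMPPasses()` entry point, which a tool calls before parsing pipelines.

    #include "mlir/Dialect/OpenMP/Transforms/Passes.h"

    // Sketch: make the OpenMP passes available to textual pipeline parsing.
    static void registerMyToolPasses() { mlir::omp::registerOpenMPPasses(); }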
diff --git a/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.td
new file mode 100644
index 0000000..1fde7e0
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/Transforms/Passes.td
@@ -0,0 +1,26 @@
+//===-- Passes.td - OpenMP pass definition file ------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_TRANSFORMS_PASSES
+#define MLIR_DIALECT_OPENMP_TRANSFORMS_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+def PrepareForOMPOffloadPrivatizationPass : Pass<"omp-offload-privatization-prepare", "ModuleOp"> {
+ let summary = "Prepare OpenMP maps for privatization for deferred target tasks";
+ let description = [{
+    When generating LLVM IR for privatized variables in an OpenMP offloading directive (e.g. omp::TargetOp)
+    that creates a deferred target task (i.e. when the nowait clause is used), we need to copy the privatized
+    variable out of the stack of the generating task and onto the heap so that the deferred target task
+    can still access it. However, if such a privatized variable is also mapped, as is typically the case for
+    allocatables, the corresponding `omp::MapInfoOp` needs to be fixed up to map the new heap-allocated
+    variable rather than the original one.
+ }];
+ let dependentDialects = ["LLVM::LLVMDialect"];
+}
+#endif // MLIR_DIALECT_OPENMP_TRANSFORMS_PASSES
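
A minimal usage sketch, assuming the pass is registered as described above. Since the pass is anchored on ModuleOp, it can be added to a pass manager by its textual name.

    #include "mlir/Pass/PassManager.h"
    #include "mlir/Pass/PassRegistry.h"

    static mlir::LogicalResult runPrepare(mlir::ModuleOp module) {
      mlir::PassManager pm(module.getContext());
      // Parse the pass by its registered name, nested under builtin.module.
      if (failed(mlir::parsePassPipeline(
              "builtin.module(omp-offload-privatization-prepare)", pm)))
        return mlir::failure();
      return pm.run(module);
    }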
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 9205f16..3ba6818 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -502,6 +502,7 @@ private:
public:
/// Create an operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
+ [[deprecated("Use OpTy::create instead")]]
OpTy create(Location location, Args &&...args) {
OperationState state(location,
getCheckRegisteredInfo<OpTy>(location.getContext()));
@@ -517,9 +518,9 @@ public:
/// the results of the operation.
///
/// Note: This performs opportunistic eager folding during IR construction.
- /// The folders are designed to operate efficiently on canonical IR, which
+ /// The folders are designed to operate efficiently on canonical IR, which
/// this API does not enforce. Complete folding is only expected in the
- /// context of canonicalization which intertwine folders with pattern
+  /// context of canonicalization which intertwines folders with pattern
/// rewrites until fixed-point.
template <typename OpTy, typename... Args>
void createOrFold(SmallVectorImpl<Value> &results, Location location,
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 4d6d89f..af58778 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -433,9 +433,19 @@ inline unsigned OpResultImpl::getResultNumber() const {
template <typename Ty>
struct TypedValue : Value {
using Value::Value;
+ using ValueType = Ty;
static bool classof(Value value) { return llvm::isa<Ty>(value.getType()); }
+ /// TypedValue<B> can implicitly convert to TypedValue<A> if B is assignable
+ /// to A.
+ template <typename ToTy,
+ typename = typename std::enable_if<std::is_assignable<
+ typename ToTy::ValueType &, Ty>::value>::type>
+ operator ToTy() const {
+ return llvm::cast<ToTy>(*this);
+ }
+
/// Return the known Type
Ty getType() const { return llvm::cast<Ty>(Value::getType()); }
void setType(Ty ty) { Value::setType(ty); }
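
A hedged sketch of what the new conversion operator enables (headers elided). This mirrors the MemorySpaceCastOp accessor simplification later in this patch and assumes BaseMemRefType is assignable to PtrLikeTypeInterface.

    // The explicit cast<TypedValue<PtrLikeTypeInterface>>(v) becomes implicit.
    static mlir::TypedValue<mlir::PtrLikeTypeInterface>
    asPtrLike(mlir::TypedValue<mlir::BaseMemRefType> v) {
      return v; // the conversion operator performs llvm::cast under the hood
    }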
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 85f0fd1d..9b15435 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1927,16 +1927,16 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
else
llvm_unreachable("unsupported row length");
- const Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
- const Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
+ Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
+ Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
- const Value isEqual =
- rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, vdst0, v);
+ Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
+ LLVM::ICmpPredicate::eq, vdst0, v);
// Per `permlane(16|32)` semantics: if the first extracted element equals
// 'v', the result is the second element; otherwise it is the first.
Value vdstNew =
- rewriter.create<LLVM::SelectOp>(loc, isEqual, vdst1, vdst0);
+ LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
permuted.emplace_back(vdstNew);
}
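
Most hunks in the conversion files that follow are the same mechanical migration, away from the `rewriter.create<OpTy>(...)` member form (deprecated by the Builders.h hunk below) toward the static `OpTy::create(rewriter, ...)` form. Both spellings build the identical operation:

    // Sketch of the equivalence, using the select rewritten above.
    static mlir::Value buildSelect(mlir::RewriterBase &rewriter,
                                   mlir::Location loc, mlir::Value pred,
                                   mlir::Value lhs, mlir::Value rhs) {
      // Old, now-deprecated form:
      //   rewriter.create<mlir::LLVM::SelectOp>(loc, pred, lhs, rhs);
      return mlir::LLVM::SelectOp::create(rewriter, loc, pred, lhs, rhs);
    }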
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 42099aa..12adfe1 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -93,11 +93,11 @@ struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> {
Location loc = op.getLoc();
Value exponentReal =
- rewriter.create<arith::SIToFPOp>(loc, exponentFloatType, op.getRhs());
- Value zeroImag = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(exponentFloatType));
- Value exponent = rewriter.create<complex::CreateOp>(
- loc, op.getLhs().getType(), exponentReal, zeroImag);
+ arith::SIToFPOp::create(rewriter, loc, exponentFloatType, op.getRhs());
+ Value zeroImag = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getZeroAttr(exponentFloatType));
+ Value exponent = complex::CreateOp::create(
+ rewriter, loc, op.getLhs().getType(), exponentReal, zeroImag);
rewriter.replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
exponent, op.getFastmathAttr());
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 5613e02..0fe7239 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -937,14 +937,14 @@ struct PowiOpConversion : public OpConversionPattern<complex::PowiOp> {
auto elementType = cast<FloatType>(type.getElementType());
Value floatExponent =
- builder.create<arith::SIToFPOp>(elementType, adaptor.getRhs());
+ arith::SIToFPOp::create(builder, elementType, adaptor.getRhs());
Value zero = arith::ConstantOp::create(
builder, elementType, builder.getFloatAttr(elementType, 0.0));
Value complexExponent =
complex::CreateOp::create(builder, type, floatExponent, zero);
- auto pow = builder.create<complex::PowOp>(
- type, adaptor.getLhs(), complexExponent, op.getFastmathAttr());
+ auto pow = complex::PowOp::create(builder, type, adaptor.getLhs(),
+ complexExponent, op.getFastmathAttr());
rewriter.replaceOp(op, pow.getResult());
return success();
}
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 852c50c..d64c4d6 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -500,19 +500,19 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
assert(scope && "Expected op to be inside automatic allocation scope");
rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
- auto one = rewriter.create<LLVM::ConstantOp>(
- loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
+ auto one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
+ rewriter.getI32IntegerAttr(1));
sinPtr =
- rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+ LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
cosPtr =
- rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+ LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
}
createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
op);
- auto sinResult = rewriter.create<LLVM::LoadOp>(loc, computeType, sinPtr);
- auto cosResult = rewriter.create<LLVM::LoadOp>(loc, computeType, cosPtr);
+ auto sinResult = LLVM::LoadOp::create(rewriter, loc, computeType, sinPtr);
+ auto cosResult = LLVM::LoadOp::create(rewriter, loc, computeType, cosPtr);
rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
maybeTrunc(cosResult, inputType, rewriter)});
@@ -522,14 +522,15 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
private:
Value maybeExt(Value operand, PatternRewriter &rewriter) const {
if (isa<Float16Type, BFloat16Type>(operand.getType()))
- return rewriter.create<LLVM::FPExtOp>(
- operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
+ return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
+ Float32Type::get(rewriter.getContext()),
+ operand);
return operand;
}
Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
if (operand.getType() != type)
- return rewriter.create<LLVM::FPTruncOp>(operand.getLoc(), type, operand);
+ return LLVM::FPTruncOp::create(rewriter, operand.getLoc(), type, operand);
return operand;
}
@@ -556,7 +557,7 @@ private:
}
SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
- rewriter.create<LLVM::CallOp>(loc, funcOp, callOperands);
+ LLVM::CallOp::create(rewriter, loc, funcOp, callOperands);
}
};
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 229e40e..7cce324 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -142,8 +142,8 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
auto structType = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), {llvmOperandType, llvmOperandType});
- auto sincosOp = rewriter.create<LLVM::SincosOp>(
- loc, structType, adaptor.getOperand(), attrs.getAttrs());
+ auto sincosOp = LLVM::SincosOp::create(
+ rewriter, loc, structType, adaptor.getOperand(), attrs.getAttrs());
auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 519d9c8..71e3f88 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -394,9 +394,9 @@ private:
if (!convertedType)
return rewriter.notifyMatchFailure(whileOp, "type conversion failed");
- emitc::VariableOp var = rewriter.create<emitc::VariableOp>(
- loc, emitc::LValueType::get(convertedType), noInit);
- rewriter.create<emitc::AssignOp>(loc, var.getResult(), init);
+ auto var = emitc::VariableOp::create(
+ rewriter, loc, emitc::LValueType::get(convertedType), noInit);
+ emitc::AssignOp::create(rewriter, loc, var.getResult(), init);
loopVars.push_back(var);
}
@@ -411,11 +411,11 @@ private:
// Create a global boolean variable to store the loop condition state.
Type i1Type = IntegerType::get(context, 1);
auto globalCondition =
- rewriter.create<emitc::VariableOp>(loc, emitc::LValueType::get(i1Type),
- emitc::OpaqueAttr::get(context, ""));
+ emitc::VariableOp::create(rewriter, loc, emitc::LValueType::get(i1Type),
+ emitc::OpaqueAttr::get(context, ""));
Value conditionVal = globalCondition.getResult();
- auto loweredDo = rewriter.create<emitc::DoOp>(loc);
+ auto loweredDo = emitc::DoOp::create(rewriter, loc);
// Convert region types to match the target dialect type system.
if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(),
@@ -450,12 +450,12 @@ private:
// Convert scf.condition to condition variable assignment.
Value condition = rewriter.getRemappedValue(condOp.getCondition());
- rewriter.create<emitc::AssignOp>(loc, conditionVal, condition);
+ emitc::AssignOp::create(rewriter, loc, conditionVal, condition);
// Wrap body region in conditional to preserve scf semantics. Only create
// ifOp if after-region is non-empty.
if (whileOp.getAfterBody()->getOperations().size() > 1) {
- auto ifOp = rewriter.create<emitc::IfOp>(loc, condition, false, false);
+ auto ifOp = emitc::IfOp::create(rewriter, loc, condition, false, false);
// Prepare the after region (loop body) for merging.
Block *afterBlock = &whileOp.getAfter().front();
@@ -480,8 +480,8 @@ private:
Block *condBlock = rewriter.createBlock(&condRegion);
rewriter.setInsertionPointToStart(condBlock);
- auto exprOp = rewriter.create<emitc::ExpressionOp>(
- loc, i1Type, conditionVal, /*do_not_inline=*/false);
+ auto exprOp = emitc::ExpressionOp::create(
+ rewriter, loc, i1Type, conditionVal, /*do_not_inline=*/false);
Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion());
// Set up the expression block to load the condition variable.
@@ -490,12 +490,12 @@ private:
// Load the condition value and yield it as the expression result.
Value cond =
- rewriter.create<emitc::LoadOp>(loc, i1Type, exprBlock->getArgument(0));
- rewriter.create<emitc::YieldOp>(loc, cond);
+ emitc::LoadOp::create(rewriter, loc, i1Type, exprBlock->getArgument(0));
+ emitc::YieldOp::create(rewriter, loc, cond);
// Yield the expression as the condition region result.
rewriter.setInsertionPointToEnd(condBlock);
- rewriter.create<emitc::YieldOp>(loc, exprOp);
+ emitc::YieldOp::create(rewriter, loc, exprOp);
return success();
}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 00df14b1..29afdc2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -232,16 +232,16 @@ static Value createLinalgBodyCalculationForElementwiseOp(
}
intermediateType = rewriter.getIntegerType(intermediateBitWidth);
- zpAddValue = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+ zpAddValue = arith::ConstantOp::create(
+ rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
} else {
intermediateType = rewriter.getIntegerType(intermediateBitWidth);
auto arg1 =
- rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[1]);
+ arith::ExtSIOp::create(rewriter, loc, intermediateType, args[1]);
auto arg2 =
- rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[2]);
+ arith::ExtSIOp::create(rewriter, loc, intermediateType, args[2]);
zpAddValue =
- rewriter.create<arith::AddIOp>(loc, intermediateType, arg1, arg2);
+ arith::AddIOp::create(rewriter, loc, intermediateType, arg1, arg2);
}
// The negation can be applied by doing:
@@ -1402,8 +1402,8 @@ static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
auto elemType = inputType.getElementType();
auto collapsedType = RankedTensorType::get({}, elemType);
// Emit the collapse op
- return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
- reassociation);
+ return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input,
+ reassociation);
}
static llvm::SmallVector<int8_t>
@@ -1443,7 +1443,7 @@ static void setupLinalgGenericOpInputAndIndexingMap(
IntegerAttr intAttr = isShift
? rewriter.getI8IntegerAttr(values.front())
: rewriter.getI32IntegerAttr(values.front());
- constant = rewriter.create<arith::ConstantOp>(loc, intAttr);
+ constant = arith::ConstantOp::create(rewriter, loc, intAttr);
} else {
auto elementType =
isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
@@ -1511,14 +1511,14 @@ static Value getExtendZp(OpBuilder &builder, Type valueTy,
.getResult(0);
}
if (zpTy.isUnsignedInteger()) {
- return builder.create<arith::ExtUIOp>(loc, extendType, result);
+ return arith::ExtUIOp::create(builder, loc, extendType, result);
} else {
- return builder.create<arith::ExtSIOp>(loc, extendType, result);
+ return arith::ExtSIOp::create(builder, loc, extendType, result);
}
}
} else {
- return builder.create<arith::ConstantOp>(
- loc, IntegerAttr::get(extendType, *maybeZp));
+ return arith::ConstantOp::create(builder, loc,
+ IntegerAttr::get(extendType, *maybeZp));
}
return result;
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 2a8c330..f0de4db 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -320,6 +320,51 @@ LogicalResult ConvertF32x2ToF4x2Op::verify() {
return success();
}
+LogicalResult ConvertF8x2ToF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx)
+ << " types are supported for conversions from f8x2 to f16x2.";
+
+ return success();
+}
+
+LogicalResult ConvertF8x2ToBF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+ if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float8E8M0FNUType::get(ctx)
+ << " type is supported for conversions from f8x2 to bf16x2.";
+
+ return success();
+}
+
+LogicalResult ConvertF6x2ToF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float6E2M3FNType::get(ctx) << " and "
+ << mlir::Float6E3M2FNType::get(ctx)
+ << " types are supported for conversions from f6x2 to f16x2.";
+
+ return success();
+}
+
+LogicalResult ConvertF4x2ToF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from f4x2 to f16x2.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2187,6 +2232,98 @@ ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
+NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
+
+ bool hasRelu = curOp.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case<Float8E4M3FNType>([&](Float8E4M3FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
+ })
+ .Case<Float8E5M2Type>([&](Float8E5M2Type type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
+ })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
+
+ llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
+
+ bool hasRelu = curOp.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case<Float6E2M3FNType>([&](Float6E2M3FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
+ })
+ .Case<Float6E3M2FNType>([&](Float6E3M2FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
+ })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
+
+ bool hasRelu = curOp.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case<Float4E2M1FNType>([&](Float4E2M1FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
+ })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *extendedI16 =
+ builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {extendedI16}};
+}
+
llvm::Intrinsic::ID
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
LLVM::ModuleTranslation &mt,
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
index d4ff095..37a45d4 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
@@ -18,4 +18,5 @@ add_mlir_dialect_library(MLIRLLVMIRTransforms
MLIRPass
MLIRTransforms
MLIRNVVMDialect
+ MLIROpenMPDialect
)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9a8a63e..794dda9 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -437,13 +437,15 @@ transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
if (!ShapedType::isDynamic(dim))
continue;
- Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos);
- auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst);
+ Value cst =
+ arith::ConstantIndexOp::create(rewriter, tensor.getLoc(), pos);
+ auto dimOp =
+ tensor::DimOp::create(rewriter, tensor.getLoc(), tensor, cst);
preservedOps.insert(dimOp);
dynamicDims.push_back(dimOp);
}
- auto allocation = rewriter.create<bufferization::AllocTensorOp>(
- tensor.getLoc(), type, dynamicDims);
+ auto allocation = bufferization::AllocTensorOp::create(
+ rewriter, tensor.getLoc(), type, dynamicDims);
// Set memory space if provided.
if (getMemorySpaceAttr())
allocation.setMemorySpaceAttr(getMemorySpaceAttr());
@@ -452,8 +454,8 @@ transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
// Only insert a materialization (typically bufferizes to a copy) when the
// value may be read from.
if (needsMaterialization) {
- auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>(
- tensor.getLoc(), tensor, allocated);
+ auto copy = bufferization::MaterializeInDestinationOp::create(
+ rewriter, tensor.getLoc(), tensor, allocated);
preservedOps.insert(copy);
promoted.push_back(copy.getResult());
} else {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
index 15eb51a..5e10ba3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
@@ -43,6 +44,33 @@ struct StructuredOpInterface
auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
auto one = arith::ConstantIndexOp::create(builder, loc, 1);
+ Value iterationDomainIsNonDegenerate;
+ for (auto [start, end] : llvm::zip(starts, ends)) {
+ auto startValue = getValueOrCreateConstantIndexOp(builder, loc, start);
+ auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
+
+      // Loop trip count > 0 iff start < end
+ Value dimensionHasNonZeroTripCount = index::CmpOp::create(
+ builder, loc, index::IndexCmpPredicate::SLT, startValue, endValue);
+
+ if (!iterationDomainIsNonDegenerate) {
+ iterationDomainIsNonDegenerate = dimensionHasNonZeroTripCount;
+ } else {
+ // Iteration domain is non-degenerate iff all dimensions have loop trip
+ // count > 0
+ iterationDomainIsNonDegenerate =
+ arith::AndIOp::create(builder, loc, iterationDomainIsNonDegenerate,
+ dimensionHasNonZeroTripCount);
+ }
+ }
+
+ if (!iterationDomainIsNonDegenerate)
+ return;
+
+ auto ifOp = scf::IfOp::create(builder, loc, iterationDomainIsNonDegenerate,
+ /*withElseRegion=*/false);
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
// Subtract one from the loop ends before composing with the indexing map
transform(ends, ends.begin(), [&](OpFoldResult end) {
auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
@@ -110,6 +138,7 @@ struct StructuredOpInterface
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
}
}
+ builder.setInsertionPointAfter(ifOp);
}
};
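
The effect of this hunk is to wrap the generated bounds assertions in an `scf.if` guarded by the non-degeneracy of the iteration domain. A C++ analogue of the emitted logic, as a sketch rather than actual pass output:

    #include <cassert>

    // Per dimension, trip count > 0 iff start < end; the in-bounds assert
    // only runs when every dimension is non-degenerate, so empty iteration
    // domains no longer trigger spurious verification failures.
    static void checkedAccessSketch(int start, int end, int size) {
      bool nonDegenerate = start < end;
      if (nonDegenerate)
        assert(end - 1 < size && "runtime verification: out-of-bounds access");
    }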
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 94947b7..c551fba 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1437,6 +1437,13 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
atLeastOneReplacement |= replaceConstantUsesOf(
builder, getLoc(), getStrides(), getConstifiedMixedStrides());
+ // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
+ if (auto prev = getSource().getDefiningOp<CastOp>())
+ if (isa<MemRefType>(prev.getSource().getType())) {
+ getSourceMutable().assign(prev.getSource());
+ atLeastOneReplacement = true;
+ }
+
return success(atLeastOneReplacement);
}
@@ -1744,11 +1751,11 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
}
TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
- return cast<TypedValue<PtrLikeTypeInterface>>(getSource());
+ return getSource();
}
TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
- return cast<TypedValue<PtrLikeTypeInterface>>(getDest());
+ return getDest();
}
bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index d35566a..bd02516 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -1033,91 +1033,6 @@ class ExtractStridedMetadataOpReinterpretCastFolder
}
};
-/// Replace `base, offset, sizes, strides =
-/// extract_strided_metadata(
-/// cast(src) to dstTy)`
-/// With
-/// ```
-/// base, ... = extract_strided_metadata(src)
-/// offset = !dstTy.srcOffset.isDynamic()
-/// ? dstTy.srcOffset
-/// : extract_strided_metadata(src).offset
-/// sizes = for each srcSize in dstTy.srcSizes:
-/// !srcSize.isDynamic()
-/// ? srcSize
-// : extract_strided_metadata(src).sizes[i]
-/// strides = for each srcStride in dstTy.srcStrides:
-/// !srcStrides.isDynamic()
-/// ? srcStrides
-/// : extract_strided_metadata(src).strides[i]
-/// ```
-///
-/// In other words, consume the `cast` and apply its effects
-/// on the offset, sizes, and strides or compute them directly from `src`.
-class ExtractStridedMetadataOpCastFolder
- : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult
- matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
- PatternRewriter &rewriter) const override {
- Value source = extractStridedMetadataOp.getSource();
- auto castOp = source.getDefiningOp<memref::CastOp>();
- if (!castOp)
- return failure();
-
- Location loc = extractStridedMetadataOp.getLoc();
- // Check if the source is suitable for extract_strided_metadata.
- SmallVector<Type> inferredReturnTypes;
- if (failed(extractStridedMetadataOp.inferReturnTypes(
- rewriter.getContext(), loc, {castOp.getSource()},
- /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
- inferredReturnTypes)))
- return rewriter.notifyMatchFailure(castOp,
- "cast source's type is incompatible");
-
- auto memrefType = cast<MemRefType>(source.getType());
- unsigned rank = memrefType.getRank();
- SmallVector<OpFoldResult> results;
- results.resize_for_overwrite(rank * 2 + 2);
-
- auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
- rewriter, loc, castOp.getSource());
-
- // Register the base_buffer.
- results[0] = newExtractStridedMetadata.getBaseBuffer();
-
- auto getConstantOrValue = [&rewriter](int64_t constant,
- OpFoldResult ofr) -> OpFoldResult {
- return ShapedType::isStatic(constant)
- ? OpFoldResult(rewriter.getIndexAttr(constant))
- : ofr;
- };
-
- auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
- assert(sourceStrides.size() == rank && "unexpected number of strides");
-
- // Register the new offset.
- results[1] =
- getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
-
- const unsigned sizeStartIdx = 2;
- const unsigned strideStartIdx = sizeStartIdx + rank;
- ArrayRef<int64_t> sourceSizes = memrefType.getShape();
-
- SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
- SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
- for (unsigned i = 0; i < rank; ++i) {
- results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
- results[strideStartIdx + i] =
- getConstantOrValue(sourceStrides[i], strides[i]);
- }
- rewriter.replaceOp(extractStridedMetadataOp,
- getValueOrCreateConstantIndexOp(rewriter, loc, results));
- return success();
- }
-};
-
/// Replace `base, offset, sizes, strides = extract_strided_metadata(
/// memory_space_cast(src) to dstTy)`
/// with
@@ -1209,7 +1124,6 @@ void memref::populateExpandStridedMetadataPatterns(
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
ExtractStridedMetadataOpSubviewFolder,
- ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
@@ -1226,7 +1140,6 @@ void memref::populateResolveExtractStridedMetadataPatterns(
ExtractStridedMetadataOpSubviewFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
ExtractStridedMetadataOpReinterpretCastFolder,
- ExtractStridedMetadataOpCastFolder,
ExtractStridedMetadataOpMemorySpaceCastFolder,
ExtractStridedMetadataOpAssumeAlignmentFolder,
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
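
The rewrite removed here is not lost: it survives as the `extract_strided_metadata(cast(x)) -> extract_strided_metadata(x)` fold added to `ExtractStridedMetadataOp::fold` in MemRefOps.cpp earlier in this patch, where the constified accessors already materialize any static offset, sizes, and strides.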
diff --git a/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt
new file mode 100644
index 0000000..f305068
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIROpenACCAnalysis
+ OpenACCSupport.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIROpenACCDialect
+ MLIROpenACCUtils
+ MLIRSupport
+)
+
diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
new file mode 100644
index 0000000..f6b4534
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
@@ -0,0 +1,26 @@
+//===- OpenACCSupport.cpp - OpenACCSupport Implementation -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the OpenACCSupport analysis interface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
+
+namespace mlir {
+namespace acc {
+
+std::string OpenACCSupport::getVariableName(Value v) {
+ if (impl)
+ return impl->getVariableName(v);
+ return acc::getVariableName(v);
+}
+
+} // namespace acc
+} // namespace mlir
diff --git a/mlir/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/CMakeLists.txt
index 7117520..e8a916e 100644
--- a/mlir/lib/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(Analysis)
add_subdirectory(IR)
add_subdirectory(Utils)
add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 5ca0100..ca46629 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -610,6 +610,20 @@ LogicalResult acc::FirstprivateOp::verify() {
}
//===----------------------------------------------------------------------===//
+// FirstprivateMapInitialOp
+//===----------------------------------------------------------------------===//
+LogicalResult acc::FirstprivateMapInitialOp::verify() {
+ if (getDataClause() != acc::DataClause::acc_firstprivate)
+ return emitError("data clause associated with firstprivate operation must "
+ "match its intent");
+ if (failed(checkVarAndVarType(*this)))
+ return failure();
+ if (failed(checkNoModifier(*this)))
+ return failure();
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//
LogicalResult acc::ReductionOp::verify() {
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index 1223325..89adda82 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/TypeSwitch.h"
mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region &region) {
@@ -78,3 +79,30 @@ mlir::acc::VariableTypeCategory mlir::acc::getTypeCategory(mlir::Value var) {
pointerLikeTy.getElementType());
return typeCategory;
}
+
+std::string mlir::acc::getVariableName(mlir::Value v) {
+ Value current = v;
+
+  // Walk through view-like operations until a name is found or we cannot go further
+ while (Operation *definingOp = current.getDefiningOp()) {
+ // Check for `acc.var_name` attribute
+ if (auto varNameAttr =
+ definingOp->getAttrOfType<VarNameAttr>(getVarNameAttrName()))
+ return varNameAttr.getName().str();
+
+ // If it is a data entry operation, get name via getVarName
+ if (isa<ACC_DATA_ENTRY_OPS>(definingOp))
+ if (auto name = acc::getVarName(definingOp))
+ return name->str();
+
+ // If it's a view operation, continue to the source
+ if (auto viewOp = dyn_cast<ViewLikeOpInterface>(definingOp)) {
+ current = viewOp.getViewSource();
+ continue;
+ }
+
+ break;
+ }
+
+ return "";
+}
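
A small usage sketch of the new utility (the helper name is illustrative only):

    // Recover a human-readable name for diagnostics, with a fallback for
    // values that reach no `acc.var_name` attribute.
    static std::string nameForDiag(mlir::Value v) {
      std::string name = mlir::acc::getVariableName(v);
      return name.empty() ? "<unnamed>" : name;
    }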
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
index 57a6d34..f3c02da 100644
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -1,3 +1,5 @@
+add_subdirectory(Transforms)
+
add_mlir_dialect_library(MLIROpenMPDialect
IR/OpenMPDialect.cpp
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..b9b8eda
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIROpenMPTransforms
+ OpenMPOffloadPrivatizationPrepare.cpp
+
+ DEPENDS
+ MLIROpenMPPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRFuncDialect
+ MLIRLLVMDialect
+ MLIROpenMPDialect
+ MLIRPass
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp
new file mode 100644
index 0000000..a9125ec
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp
@@ -0,0 +1,445 @@
+//===- OpenMPOffloadPrivatizationPrepare.cpp - Prepare OMP privatization --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <cstdint>
+#include <iterator>
+#include <utility>
+
+//===----------------------------------------------------------------------===//
+// A pass that prepares OpenMP code for translation of delayed privatization
+// in the context of deferred target tasks. Deferred target tasks are created
+// when the nowait clause is used on the target directive.
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "omp-prepare-for-offload-privatization"
+
+namespace mlir {
+namespace omp {
+
+#define GEN_PASS_DEF_PREPAREFOROMPOFFLOADPRIVATIZATIONPASS
+#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
+
+} // namespace omp
+} // namespace mlir
+
+using namespace mlir;
+namespace {
+
+//===----------------------------------------------------------------------===//
+// PrepareForOMPOffloadPrivatizationPass
+//===----------------------------------------------------------------------===//
+
+class PrepareForOMPOffloadPrivatizationPass
+ : public omp::impl::PrepareForOMPOffloadPrivatizationPassBase<
+ PrepareForOMPOffloadPrivatizationPass> {
+
+ void runOnOperation() override {
+ ModuleOp mod = getOperation();
+
+ // In this pass, we make host-allocated privatized variables persist for
+ // deferred target tasks by copying them to the heap. Once the target task
+    // is done, this heap memory is freed. Since all of this happens on the
+    // host, we can skip device modules.
+ auto offloadModuleInterface =
+ dyn_cast<omp::OffloadModuleInterface>(mod.getOperation());
+ if (offloadModuleInterface && offloadModuleInterface.getIsTargetDevice())
+ return;
+
+ getOperation()->walk([&](omp::TargetOp targetOp) {
+ if (!hasPrivateVars(targetOp) || !isTargetTaskDeferred(targetOp))
+ return;
+ IRRewriter rewriter(&getContext());
+ OperandRange privateVars = targetOp.getPrivateVars();
+ SmallVector<mlir::Value> newPrivVars;
+ Value fakeDependVar;
+ omp::TaskOp cleanupTaskOp;
+
+ newPrivVars.reserve(privateVars.size());
+ std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
+ for (auto [privVarIdx, privVarSymPair] :
+ llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
+ Value privVar = std::get<0>(privVarSymPair);
+ Attribute privSym = std::get<1>(privVarSymPair);
+
+ omp::PrivateClauseOp privatizer = findPrivatizer(targetOp, privSym);
+ if (!privatizer.needsMap()) {
+ newPrivVars.push_back(privVar);
+ continue;
+ }
+ bool isFirstPrivate = privatizer.getDataSharingType() ==
+ omp::DataSharingClauseType::FirstPrivate;
+
+ Value mappedValue = targetOp.getMappedValueForPrivateVar(privVarIdx);
+ auto mapInfoOp = cast<omp::MapInfoOp>(mappedValue.getDefiningOp());
+
+ if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy) {
+ newPrivVars.push_back(privVar);
+ continue;
+ }
+
+        // For deferred target tasks (!$omp target nowait), we need to keep
+        // a copy of the original (i.e. host) variable being privatized so
+        // that it is available when the target task is eventually executed.
+        // We do this by first allocating as much heap memory as is needed by
+        // the original variable. Then, we use the init and copy regions of
+        // the privatizer (an instance of omp::PrivateClauseOp) to set up the
+        // heap-allocated copy.
+        // After the target task is done, we need to use the dealloc region
+        // of the privatizer to clean up everything. We also need to free
+        // the heap memory we allocated. But due to the deferred nature of
+        // the target task, we cannot simply deallocate right after the
+        // omp.target operation; otherwise we may end up freeing the memory
+        // before its eventual use by the target task. So, we create a dummy
+        // dependence between the target task and a new omp.task in which we
+        // do all the cleanup. We end up with the following structure:
+ //
+ // omp.target map_entries(..) ... nowait depend(out:fakeDependVar) {
+ // ...
+ // omp.terminator
+ // }
+ // omp.task depend(in: fakeDependVar) {
+ // /*cleanup_code*/
+ // omp.terminator
+ // }
+ // fakeDependVar is the address of the first heap-allocated copy of the
+ // host variable being privatized.
+
+ bool needsCleanupTask = !privatizer.getDeallocRegion().empty();
+
+        // Allocate heap memory that corresponds to the type of memory
+        // pointed to by varPtr.
+        // For boxchars, this won't be a pointer. However,
+        // MapsForPrivatizedSymbols should have mapped the pointer to the
+        // boxchar, so use that as varPtr.
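+        // For example, a Fortran boxchar lowers to !llvm.struct<(ptr, i64)>:
+        // privVar then carries the struct value while varPtr is the
+        // !llvm.ptr to the stack slot holding it (as in the by-value test).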
+ Value varPtr = mapInfoOp.getVarPtr();
+ Type varType = mapInfoOp.getVarType();
+ bool isPrivatizedByValue =
+ !isa<LLVM::LLVMPointerType>(privVar.getType());
+
+        assert(isa<LLVM::LLVMPointerType>(varPtr.getType()) &&
+               "varPtr must be an LLVM pointer");
+ Value heapMem =
+ allocateHeapMem(targetOp, varPtr, varType, mod, rewriter);
+        if (!heapMem) {
+          targetOp.emitError(
+              "unable to allocate heap memory when trying to move "
+              "a private variable out of the stack and into the "
+              "heap for use by a deferred target task");
+          return;
+        }
+
+ if (needsCleanupTask && !fakeDependVar)
+ fakeDependVar = heapMem;
+
+ // The types of private vars should match before and after the
+ // transformation. In particular, if the type is a pointer,
+ // simply record the newly allocated malloc location as the
+        // new private variable. If, however, the type is not a pointer,
+        // then we need to load the value from the newly allocated
+ // location. We'll insert that load later after we have updated
+ // the malloc'd location with the contents of the original
+ // variable.
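+        // For example, an !llvm.ptr private variable is simply replaced by
+        // heapMem, whereas a by-value !llvm.struct<(ptr, i64)> variable is
+        // re-loaded from heapMem just before the omp.target op.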
+ if (!isPrivatizedByValue)
+ newPrivVars.push_back(heapMem);
+
+ // We now need to copy the original private variable into the newly
+ // allocated location in the heap.
+        // Find the earliest insertion point for the copy. This will be
+        // before the first of the omp::MapInfoOp instances that use varPtr.
+        // After the copy, these omp::MapInfoOp instances will refer to
+        // heapMem instead.
+ Operation *varPtrDefiningOp = varPtr.getDefiningOp();
+ DenseSet<Operation *> users;
+ if (varPtrDefiningOp) {
+ users.insert(varPtrDefiningOp->user_begin(),
+ varPtrDefiningOp->user_end());
+ } else {
+ auto blockArg = cast<BlockArgument>(varPtr);
+ users.insert(blockArg.user_begin(), blockArg.user_end());
+ }
+ auto usesVarPtr = [&users](Operation *op) -> bool {
+ return users.count(op);
+ };
+
+ SmallVector<Operation *> chainOfOps;
+ chainOfOps.push_back(mapInfoOp);
+ for (auto member : mapInfoOp.getMembers()) {
+ omp::MapInfoOp memberMap =
+ cast<omp::MapInfoOp>(member.getDefiningOp());
+ if (usesVarPtr(memberMap))
+ chainOfOps.push_back(memberMap);
+ if (memberMap.getVarPtrPtr()) {
+ Operation *defOp = memberMap.getVarPtrPtr().getDefiningOp();
+ if (defOp && usesVarPtr(defOp))
+ chainOfOps.push_back(defOp);
+ }
+ }
+
+ DominanceInfo dom;
+ llvm::sort(chainOfOps, [&](Operation *l, Operation *r) {
+ return dom.dominates(l, r);
+ });
+
+ rewriter.setInsertionPoint(chainOfOps.front());
+
+ Operation *firstOp = chainOfOps.front();
+ Location loc = firstOp->getLoc();
+
+        // Create an llvm.func for 'region' that is marked always_inline,
+        // and call it.
+ auto createAlwaysInlineFuncAndCallIt =
+ [&](Region &region, llvm::StringRef funcName,
+ llvm::ArrayRef<Value> args, bool returnsValue) -> Value {
+ assert(!region.empty() && "region cannot be empty");
+ LLVM::LLVMFuncOp func = createFuncOpForRegion(
+ loc, mod, region, funcName, rewriter, returnsValue);
+ auto call = LLVM::CallOp::create(rewriter, loc, func, args);
+          return returnsValue ? call.getResult() : Value();
+ };
+
+ Value moldArg, newArg;
+ if (isPrivatizedByValue) {
+ moldArg = LLVM::LoadOp::create(rewriter, loc, varType, varPtr);
+ newArg = LLVM::LoadOp::create(rewriter, loc, varType, heapMem);
+ } else {
+ moldArg = varPtr;
+ newArg = heapMem;
+ }
+
+ Value initializedVal;
+ if (!privatizer.getInitRegion().empty())
+ initializedVal = createAlwaysInlineFuncAndCallIt(
+ privatizer.getInitRegion(),
+ llvm::formatv("{0}_{1}", privatizer.getSymName(), "init").str(),
+ {moldArg, newArg}, /*returnsValue=*/true);
+ else
+ initializedVal = newArg;
+
+ if (isFirstPrivate && !privatizer.getCopyRegion().empty())
+ initializedVal = createAlwaysInlineFuncAndCallIt(
+ privatizer.getCopyRegion(),
+ llvm::formatv("{0}_{1}", privatizer.getSymName(), "copy").str(),
+ {moldArg, initializedVal}, /*returnsValue=*/true);
+
+ if (isPrivatizedByValue)
+ (void)LLVM::StoreOp::create(rewriter, loc, initializedVal, heapMem);
+
+        // Clone origOp, replace all uses of varPtr with heapMem, and
+        // erase origOp.
+ auto cloneModifyAndErase = [&](Operation *origOp) -> Operation * {
+ Operation *clonedOp = rewriter.clone(*origOp);
+ rewriter.replaceAllOpUsesWith(origOp, clonedOp);
+ rewriter.modifyOpInPlace(clonedOp, [&]() {
+ clonedOp->replaceUsesOfWith(varPtr, heapMem);
+ });
+ rewriter.eraseOp(origOp);
+ return clonedOp;
+ };
+
+ // Now that we have set up the heap-allocated copy of the private
+ // variable, rewrite all the uses of the original variable with
+ // the heap-allocated variable.
+ rewriter.setInsertionPoint(targetOp);
+ mapInfoOp = cast<omp::MapInfoOp>(cloneModifyAndErase(mapInfoOp));
+ rewriter.setInsertionPoint(mapInfoOp);
+
+        // Fix any members that may use varPtr to now use heapMem.
+ for (auto member : mapInfoOp.getMembers()) {
+ auto memberMapInfoOp = cast<omp::MapInfoOp>(member.getDefiningOp());
+ if (!usesVarPtr(memberMapInfoOp))
+ continue;
+ memberMapInfoOp =
+ cast<omp::MapInfoOp>(cloneModifyAndErase(memberMapInfoOp));
+ rewriter.setInsertionPoint(memberMapInfoOp);
+
+ if (memberMapInfoOp.getVarPtrPtr()) {
+            Operation *varPtrPtrDefOp =
+                memberMapInfoOp.getVarPtrPtr().getDefiningOp();
+            rewriter.setInsertionPoint(cloneModifyAndErase(varPtrPtrDefOp));
+ }
+ }
+
+ // If the type of the private variable is not a pointer,
+ // which is typically the case with !fir.boxchar types, then
+ // we need to ensure that the new private variable is also
+ // not a pointer. Insert a load from heapMem right before
+ // targetOp.
+ if (isPrivatizedByValue) {
+ rewriter.setInsertionPoint(targetOp);
+ auto newPrivVar = LLVM::LoadOp::create(rewriter, mapInfoOp.getLoc(),
+ varType, heapMem);
+ newPrivVars.push_back(newPrivVar);
+ }
+
+ // Deallocate
+ if (needsCleanupTask) {
+ if (!cleanupTaskOp) {
+ assert(fakeDependVar &&
+ "Need a valid value to set up a dependency");
+ rewriter.setInsertionPointAfter(targetOp);
+ omp::TaskOperands taskOperands;
+ auto inDepend = omp::ClauseTaskDependAttr::get(
+ rewriter.getContext(), omp::ClauseTaskDepend::taskdependin);
+ taskOperands.dependKinds.push_back(inDepend);
+ taskOperands.dependVars.push_back(fakeDependVar);
+ cleanupTaskOp = omp::TaskOp::create(rewriter, loc, taskOperands);
+ Block *taskBlock = rewriter.createBlock(&cleanupTaskOp.getRegion());
+ rewriter.setInsertionPointToEnd(taskBlock);
+ omp::TerminatorOp::create(rewriter, cleanupTaskOp.getLoc());
+ }
+ rewriter.setInsertionPointToStart(
+ &*cleanupTaskOp.getRegion().getBlocks().begin());
+ (void)createAlwaysInlineFuncAndCallIt(
+ privatizer.getDeallocRegion(),
+ llvm::formatv("{0}_{1}", privatizer.getSymName(), "dealloc")
+ .str(),
+ {initializedVal}, /*returnsValue=*/false);
+ llvm::FailureOr<LLVM::LLVMFuncOp> freeFunc =
+ LLVM::lookupOrCreateFreeFn(rewriter, mod);
+ assert(llvm::succeeded(freeFunc) &&
+ "Could not find free in the module");
+ (void)LLVM::CallOp::create(rewriter, loc, freeFunc.value(),
+ ValueRange{heapMem});
+ }
+ }
+ assert(newPrivVars.size() == privateVars.size() &&
+ "The number of private variables must match before and after "
+ "transformation");
+ if (fakeDependVar) {
+ omp::ClauseTaskDependAttr outDepend = omp::ClauseTaskDependAttr::get(
+ rewriter.getContext(), omp::ClauseTaskDepend::taskdependout);
+ SmallVector<Attribute> newDependKinds;
+ if (!targetOp.getDependVars().empty()) {
+ std::optional<ArrayAttr> dependKinds = targetOp.getDependKinds();
+ assert(dependKinds && "bad depend clause in omp::TargetOp");
+ llvm::copy(*dependKinds, std::back_inserter(newDependKinds));
+ }
+ newDependKinds.push_back(outDepend);
+ ArrayAttr newDependKindsAttr =
+ ArrayAttr::get(rewriter.getContext(), newDependKinds);
+ targetOp.getDependVarsMutable().append(fakeDependVar);
+ targetOp.setDependKindsAttr(newDependKindsAttr);
+ }
+ rewriter.setInsertionPoint(targetOp);
+ targetOp.getPrivateVarsMutable().clear();
+ targetOp.getPrivateVarsMutable().assign(newPrivVars);
+ });
+ }
+
+private:
+ bool hasPrivateVars(omp::TargetOp targetOp) const {
+ return !targetOp.getPrivateVars().empty();
+ }
+
+ bool isTargetTaskDeferred(omp::TargetOp targetOp) const {
+ return targetOp.getNowait();
+ }
+
+ template <typename OpTy>
+ omp::PrivateClauseOp findPrivatizer(OpTy op, Attribute privSym) const {
+ SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
+ omp::PrivateClauseOp privatizer =
+ SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
+ op, privatizerName);
+ return privatizer;
+ }
+
+  // Get the (compile-time constant) size of varType, rounded up to its ABI
+  // alignment, as per the given DataLayout dl.
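+  // For example, with the x86-64 data layout used in the tests,
+  // !llvm.struct<(ptr, i64)> has size 16 and ABI alignment 8, so this
+  // returns 16 (matching the malloc(16) in the by-value test).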
+ std::int64_t getSizeInBytes(const DataLayout &dl, Type varType) const {
+ llvm::TypeSize size = dl.getTypeSize(varType);
+    std::uint64_t alignment = dl.getTypeABIAlignment(varType);
+ return llvm::alignTo(size, alignment);
+ }
+
+ LLVM::LLVMFuncOp getMalloc(ModuleOp mod, IRRewriter &rewriter) const {
+ llvm::FailureOr<LLVM::LLVMFuncOp> mallocCall =
+ LLVM::lookupOrCreateMallocFn(rewriter, mod, rewriter.getI64Type());
+ assert(llvm::succeeded(mallocCall) &&
+ "Could not find malloc in the module");
+ return mallocCall.value();
+ }
+
+ Value allocateHeapMem(omp::TargetOp targetOp, Value privVar, Type varType,
+ ModuleOp mod, IRRewriter &rewriter) const {
+ OpBuilder::InsertionGuard guard(rewriter);
+ Value varPtr = privVar;
+ Operation *definingOp = varPtr.getDefiningOp();
+ BlockArgument blockArg;
+ if (!definingOp) {
+ blockArg = mlir::dyn_cast<BlockArgument>(varPtr);
+ rewriter.setInsertionPointToStart(blockArg.getParentBlock());
+ } else {
+ rewriter.setInsertionPoint(definingOp);
+ }
+ Location loc = definingOp ? definingOp->getLoc() : blockArg.getLoc();
+ LLVM::LLVMFuncOp mallocFn = getMalloc(mod, rewriter);
+
+ assert(mod.getDataLayoutSpec() &&
+ "MLIR module with no datalayout spec not handled yet");
+
+    const DataLayout dl(mod);
+ std::int64_t distance = getSizeInBytes(dl, varType);
+
+ Value sizeBytes = LLVM::ConstantOp::create(
+ rewriter, loc, mallocFn.getFunctionType().getParamType(0), distance);
+
+ auto mallocCallOp =
+ LLVM::CallOp::create(rewriter, loc, mallocFn, ValueRange{sizeBytes});
+ return mallocCallOp.getResult();
+ }
+
+  // Create a function for srcRegion and mark it always_inline.
+  // The big assumption here is that srcRegion is one of the init, copy or
+  // dealloc regions of an omp::PrivateClauseOp. Accordingly, the return type
+  // is assumed to either match the (common) type of the region's two
+  // arguments (for init and copy regions) or be void (for dealloc regions).
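+  // For example, the init region of @boxchar_firstprivate in the tests
+  // becomes
+  //   llvm.func @boxchar_firstprivate_init(!llvm.struct<(ptr, i64)>,
+  //                                        !llvm.struct<(ptr, i64)>)
+  //       -> !llvm.struct<(ptr, i64)> attributes {always_inline}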
+ LLVM::LLVMFuncOp createFuncOpForRegion(Location loc, ModuleOp mod,
+ Region &srcRegion,
+ llvm::StringRef funcName,
+ IRRewriter &rewriter,
+ bool returnsValue = false) {
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
+ Region clonedRegion;
+ IRMapping mapper;
+ srcRegion.cloneInto(&clonedRegion, mapper);
+
+ SmallVector<Type> paramTypes;
+ llvm::copy(srcRegion.getArgumentTypes(), std::back_inserter(paramTypes));
+ Type resultType = returnsValue
+ ? srcRegion.getArgument(0).getType()
+ : LLVM::LLVMVoidType::get(rewriter.getContext());
+ LLVM::LLVMFunctionType funcType =
+ LLVM::LLVMFunctionType::get(resultType, paramTypes);
+
+ LLVM::LLVMFuncOp func =
+ LLVM::LLVMFuncOp::create(rewriter, loc, funcName, funcType);
+ func.setAlwaysInline(true);
+ rewriter.inlineRegionBefore(clonedRegion, func.getRegion(),
+ func.getRegion().end());
+ for (auto &block : func.getRegion().getBlocks()) {
+      if (auto yieldOp = dyn_cast<omp::YieldOp>(block.getTerminator())) {
+        rewriter.setInsertionPoint(yieldOp);
+        rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(yieldOp, TypeRange(),
+                                                    yieldOp.getOperands());
+      }
+ }
+ return func;
+ }
+};
+} // namespace
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index a9da6c2..9bd13f3 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2490,8 +2490,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
changed = true;
if (!constantTrue)
- constantTrue = rewriter.create<arith::ConstantOp>(
- op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
+ constantTrue = arith::ConstantOp::create(
+ rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
rewriter.modifyOpInPlace(use.getOwner(),
[&]() { use.set(constantTrue); });
@@ -2500,8 +2500,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
changed = true;
if (!constantFalse)
- constantFalse = rewriter.create<arith::ConstantOp>(
- op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
+ constantFalse = arith::ConstantOp::create(
+ rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
rewriter.modifyOpInPlace(use.getOwner(),
[&]() { use.set(constantFalse); });
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 5dc61a2..335ca1a 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -69,10 +69,10 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
Sharding sourceSharding,
TypedValue<ShapedType> sourceShard, GridOp grid,
int64_t splitTensorAxis, GridAxis splitGridAxis) {
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+ TypedValue<ShapedType> targetShard =
AllSliceOp::create(builder, sourceShard, grid,
ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis)
- .getResult());
+ .getResult();
Sharding targetSharding = targetShardingInSplitLastAxis(
builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis);
return {targetShard, targetSharding};
@@ -204,9 +204,8 @@ static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
APInt(64, splitTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, grid, targetSharding);
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
- tensor::CastOp::create(builder, targetShape, allGatherResult)
- .getResult());
+ TypedValue<ShapedType> targetShard =
+ tensor::CastOp::create(builder, targetShape, allGatherResult).getResult();
return {targetShard, targetSharding};
}
@@ -336,8 +335,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, grid, targetSharding);
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
- tensor::CastOp::create(builder, targetShape, allToAllResult).getResult());
+ TypedValue<ShapedType> targetShard =
+ tensor::CastOp::create(builder, targetShape, allToAllResult).getResult();
return {targetShard, targetSharding};
}
@@ -510,8 +509,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
auto targetSharding = target.getSharding();
ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
- cast<TypedValue<ShapedType>>(source.getSrc()),
- sourceShardValue);
+ source.getSrc(), sourceShardValue);
}
TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index caf8016..99b7cda 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1001,8 +1001,12 @@ OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
!outputTy.hasStaticShape())
return {};
- if (inputTy.getDimSize(getAxis()) == 1)
- return DenseElementsAttr::get(outputTy, 0);
+ const Type outputElementTy = getElementTypeOrSelf(outputTy);
+ if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) {
+ const auto outputElemIntTy = cast<IntegerType>(outputElementTy);
+ const APInt zero = APInt::getZero(outputElemIntTy.getWidth());
+ return DenseElementsAttr::get(outputTy, zero);
+ }
return {};
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
index 8f46ad6..ef49c86 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
@@ -74,9 +74,9 @@ struct MixedSizeInputShuffleOpRewrite final
for (int64_t i = 0; i < origNumElems; ++i)
promoteMask[i] = i;
- Value promotedInput = rewriter.create<vector::ShuffleOp>(
- shuffleOp.getLoc(), promotedType, inputToPromote, inputToPromote,
- promoteMask);
+ Value promotedInput =
+ vector::ShuffleOp::create(rewriter, shuffleOp.getLoc(), promotedType,
+ inputToPromote, inputToPromote, promoteMask);
// Create the final shuffle with the promoted inputs.
Value promotedV1 = promoteV1 ? promotedInput : shuffleOp.getV1();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 7c019e7..8b5e950 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -341,13 +341,18 @@ private:
/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have a dimension equal to the
/// original type rank and should be a projection where the results are the
-/// distributed dimensions. The number of results should be equal to the number
+/// distributed dimensions. If the number of results is zero, there is no
+/// distribution (i.e. the original type is returned).
+/// Otherwise, the number of results should be equal to the number
/// of warp sizes which is currently limited to 1.
/// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1)
/// and a warp size of 16 would distribute the second dimension (associated to
/// d1) and return vector<16x2x64>
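+/// With a map that has no results, e.g. (d0, d1, d2) -> (), the same
+/// vector<16x32x64> is returned unchanged.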
static VectorType getDistributedType(VectorType originalType, AffineMap map,
int64_t warpSize) {
+ // If the map has zero results, return the original type.
+ if (map.getNumResults() == 0)
+ return originalType;
SmallVector<int64_t> targetShape(originalType.getShape());
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
unsigned position = map.getDimPosition(i);
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 1599ae9..24e9095 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -736,7 +736,7 @@ OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
OpBuilder &builder) {
auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
- return builder.create<ArithOp>(loc, aVal, bVal).getResult();
+ return ArithOp::create(builder, loc, aVal, bVal).getResult();
}
// a helper utility to perform division operation on OpFoldResult and int64_t.
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 26770b3..d09dc19 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1505,14 +1505,19 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
return AffineMap::get(val.getContext());
// Get the layout of the vector type.
xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val);
- // If no layout is specified, assume the inner most dimension is distributed
- // for now.
+ // If no layout is specified, that means no distribution.
if (!layout)
- return AffineMap::getMultiDimMapWithTargets(
- vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext());
+ return AffineMap::getMultiDimMapWithTargets(vecRank, {},
+ val.getContext());
+    // Expect the vector rank and the layout rank to match.
+    assert(layout.getRank() == vecRank &&
+           "Expecting vector and layout rank to match");
+    // A dimension is distributed only if the layout assigns multiple lanes
+    // to it and its extent can be evenly divided among those lanes.
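+    // For example, a lane layout of [1, 16] on vector<16x32xf32> distributes
+    // only dimension 1, since 32 is divisible by 16 but dimension 0 has a
+    // single lane.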
SmallVector<unsigned int> distributedDims;
for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) {
- if (v > 1)
+ if (v > 1 && vecType.getShape()[i] % v == 0)
distributedDims.push_back(i);
}
return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims,
@@ -1525,15 +1530,13 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
auto warpReduction = [](Location loc, OpBuilder &builder, Value input,
vector::CombiningKind kind, uint32_t size) {
// First reduce on a single thread to get per lane reduction value.
- Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input);
+ Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
// Parallel reduction using butterfly shuffles.
for (uint64_t i = 1; i < size; i <<= 1) {
- Value shuffled =
- builder
- .create<gpu::ShuffleOp>(loc, laneVal, i,
- /*width=*/size,
- /*mode=*/gpu::ShuffleMode::XOR)
- .getShuffleResult();
+ Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i,
+ /*width=*/size,
+ /*mode=*/gpu::ShuffleMode::XOR)
+ .getShuffleResult();
laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
}
return laneVal;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 31a967d..9fc5ad9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -825,7 +825,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
baseTileValues);
- auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
+ auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);
// Get subgroup id
Value sgId =
@@ -837,25 +837,26 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
SmallVector<Value, 2> strideConsts;
strideConsts.push_back(
- rewriter.create<arith::ConstantIndexOp>(loc, colStride));
+ arith::ConstantIndexOp::create(rewriter, loc, colStride));
if (rows > 1)
strideConsts.insert(
strideConsts.begin(),
- rewriter.create<arith::ConstantIndexOp>(loc, rowStride));
+ arith::ConstantIndexOp::create(rewriter, loc, rowStride));
SmallVector<Value> newConstOps;
for (auto offsets : *sgOffsets) {
// Multiply offset with stride, broadcast it and add to baseConstVec
- Value mulOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
for (size_t i = 0; i < strideConsts.size(); ++i) {
- Value mul = rewriter.create<arith::MulIOp>(
- loc, rewriter.getIndexType(), offsets[i], strideConsts[i]);
- mulOffset = rewriter.create<arith::AddIOp>(
- loc, rewriter.getIndexType(), mulOffset, mul);
+ Value mul =
+ arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
+ offsets[i], strideConsts[i]);
+ mulOffset = arith::AddIOp::create(
+ rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
}
// Broadcast to baseConstVec size
- auto bcastOffset = rewriter.create<vector::BroadcastOp>(
- loc, baseConstVec.getType(), mulOffset);
+ auto bcastOffset = vector::BroadcastOp::create(
+ rewriter, loc, baseConstVec.getType(), mulOffset);
auto finalConst =
arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
setLayoutIfNeeded(baseConstVec);
@@ -1138,8 +1139,8 @@ struct WgToSgVectorShapeCastOp
SmallVector<Value> newShapeCastOps;
for (auto src : adaptor.getSource()) {
- auto newShapeCast =
- rewriter.create<vector::ShapeCastOp>(op.getLoc(), newResultType, src);
+ auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+ newResultType, src);
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
!layout.getEffectiveInstDataAsInt().empty())
xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
@@ -1201,9 +1202,9 @@ struct WgToSgMultiDimReductionOp
SmallVector<Value> newReductions;
for (auto sgSrc : adaptor.getSource()) {
- auto newOp = rewriter.create<vector::MultiDimReductionOp>(
- op.getLoc(), newDstType, op.getKind(), sgSrc, adaptor.getAcc()[0],
- op.getReductionDims());
+ auto newOp = vector::MultiDimReductionOp::create(
+ rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
+ adaptor.getAcc()[0], op.getReductionDims());
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
!layout.getEffectiveInstDataAsInt().empty())
xegpu::setDistributeLayoutAttr(newOp->getResult(0),
diff --git a/mlir/lib/RegisterAllPasses.cpp b/mlir/lib/RegisterAllPasses.cpp
index dd413d2de..d7e321a 100644
--- a/mlir/lib/RegisterAllPasses.cpp
+++ b/mlir/lib/RegisterAllPasses.cpp
@@ -33,6 +33,7 @@
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+#include "mlir/Dialect/OpenMP/Transforms/Passes.h"
#include "mlir/Dialect/Quant/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
@@ -80,6 +81,7 @@ void mlir::registerAllPasses() {
memref::registerMemRefPasses();
shard::registerShardPasses();
ml_program::registerMLProgramPasses();
+ omp::registerOpenMPPasses();
quant::registerQuantPasses();
registerSCFPasses();
registerShapePasses();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index b851414..f284540 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -357,14 +357,8 @@ static LogicalResult checkImplementationStatus(Operation &op) {
result = todo("priority");
};
auto checkPrivate = [&todo](auto op, LogicalResult &result) {
- if constexpr (std::is_same_v<std::decay_t<decltype(op)>, omp::TargetOp>) {
- // Privatization is supported only for included target tasks.
- if (!op.getPrivateVars().empty() && op.getNowait())
- result = todo("privatization for deferred target tasks");
- } else {
- if (!op.getPrivateVars().empty() || op.getPrivateSyms())
- result = todo("privatization");
- }
+ if (!op.getPrivateVars().empty() || op.getPrivateSyms())
+ result = todo("privatization");
};
auto checkReduction = [&todo](auto op, LogicalResult &result) {
if (isa<omp::TeamsOp>(op))
@@ -451,7 +445,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkDevice(op, result);
checkInReduction(op, result);
checkIsDevicePtr(op, result);
- checkPrivate(op, result);
})
.Default([](Operation &) {
// Assume all clauses for an operation can be translated unless they are
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index d9ad8fb..6492708 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -702,8 +702,8 @@ spirv::Deserializer::processGraphEntryPointARM(ArrayRef<uint32_t> operands) {
// RAII guard to reset the insertion point to previous value when done.
OpBuilder::InsertionGuard insertionGuard(opBuilder);
opBuilder.setInsertionPoint(graphARM);
- opBuilder.create<spirv::GraphEntryPointARMOp>(
- unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
+ spirv::GraphEntryPointARMOp::create(
+ opBuilder, unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
opBuilder.getArrayAttr(interface));
return success();
@@ -736,7 +736,7 @@ spirv::Deserializer::processGraphARM(ArrayRef<uint32_t> operands) {
std::string graphName = getGraphSymbol(graphID);
auto graphOp =
- opBuilder.create<spirv::GraphARMOp>(unknownLoc, graphName, graphType);
+ spirv::GraphARMOp::create(opBuilder, unknownLoc, graphName, graphType);
curGraph = graphMap[graphID] = graphOp;
Block *entryBlock = graphOp.addEntryBlock();
LLVM_DEBUG({
@@ -844,7 +844,7 @@ spirv::Deserializer::processOpGraphSetOutputARM(ArrayRef<uint32_t> operands) {
LogicalResult
spirv::Deserializer::processGraphEndARM(ArrayRef<uint32_t> operands) {
// Create GraphOutputsARM instruction.
- opBuilder.create<spirv::GraphOutputsARMOp>(unknownLoc, graphOutputs);
+ spirv::GraphOutputsARMOp::create(opBuilder, unknownLoc, graphOutputs);
// Process OpGraphEndARM.
if (!operands.empty()) {
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
index 366ba8f..048e964 100644
--- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -406,7 +406,7 @@ private:
auto returnOperands = popOperands(resTypes);
if (failed(returnOperands))
return failure();
- builder.create<BlockReturnOp>(opLoc, *returnOperands);
+ BlockReturnOp::create(builder, opLoc, *returnOperands);
LDBG() << "end of parsing of a block";
return bodyParsingRes->endingByte;
}
@@ -1000,7 +1000,7 @@ parsed_inst_t ExpressionParser::parseBlockLikeOp(OpBuilder &builder) {
builder.createBlock(curRegion, curRegion->end(), resTypes, locations);
builder.setInsertionPointToEnd(curBlock);
auto blockOp =
- builder.create<OpToCreate>(*currentOpLoc, *inputOps, successor);
+ OpToCreate::create(builder, *currentOpLoc, *inputOps, successor);
auto *blockBody = blockOp.createBlock();
if (failed(parseBlockContent(builder, blockBody, resTypes, *opLoc, blockOp)))
return failure();
@@ -1047,8 +1047,8 @@ inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
auto *successor =
builder.createBlock(curRegion, curRegion->end(), resTypes, locations);
builder.setInsertionPointToEnd(curBlock);
- auto ifOp = builder.create<IfOp>(*currentOpLoc, conditionValue->front(),
- *inputOps, successor);
+ auto ifOp = IfOp::create(builder, *currentOpLoc, conditionValue->front(),
+ *inputOps, successor);
auto *ifEntryBlock = ifOp.createIfBlock();
constexpr auto ifElseFilter =
ByteSequence<WasmBinaryEncoding::endByte,
@@ -1091,9 +1091,9 @@ inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
auto branchArgs = popOperands(inputTypes);
if (failed(branchArgs))
return failure();
- builder.create<BranchIfOp>(*currentOpLoc, condition->front(),
- builder.getUI32IntegerAttr(*level), *branchArgs,
- elseBlock);
+ BranchIfOp::create(builder, *currentOpLoc, condition->front(),
+ builder.getUI32IntegerAttr(*level), *branchArgs,
+ elseBlock);
builder.setInsertionPointToStart(elseBlock);
return {*branchArgs};
}
@@ -1115,7 +1115,7 @@ ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::call>(
if (failed(inOperands))
return failure();
auto callOp =
- builder.create<FuncCallOp>(loc, resTypes, callee.symbol, *inOperands);
+ FuncCallOp::create(builder, loc, resTypes, callee.symbol, *inOperands);
return {callOp.getResults()};
}
@@ -1391,8 +1391,8 @@ inline parsed_inst_t ExpressionParser::buildConvertOp(OpBuilder &builder,
auto operand = popOperands(intype);
if (failed(operand))
return failure();
- auto op = builder.create<opType>(*currentOpLoc, outType, operand->front(),
- extraArgs...);
+ auto op = opType::create(builder, *currentOpLoc, outType, operand->front(),
+ extraArgs...);
LDBG() << "Built operation: " << op;
return {{op.getResult()}};
}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index e79b51a..20ed3ab 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -894,9 +894,16 @@ if(NOT LLVM_ENABLE_IDE)
)
endif()
+set(_mlir_python_stubgen_enabled ON)
# Stubgen doesn't work when cross-compiling (stubgen will run in the host interpreter and then fail
# to find the extension module for the host arch).
-if(NOT CMAKE_CROSSCOMPILING)
+# Note: Stubgen requires some extra handling to work properly when sanitizers are enabled,
+# so we skip running it in that case for now.
+if(CMAKE_CROSSCOMPILING OR (NOT LLVM_USE_SANITIZER STREQUAL ""))
+ set(_mlir_python_stubgen_enabled OFF)
+endif()
+
+if(_mlir_python_stubgen_enabled)
# _mlir stubgen
# Note: All this needs to come before add_mlir_python_modules(MLIRPythonModules so that the install targets for the
# generated type stubs get created.
@@ -985,7 +992,7 @@ endif()
################################################################################
set(_declared_sources MLIRPythonSources MLIRPythonExtension.RegisterEverything)
-if(NOT CMAKE_CROSSCOMPILING)
+if(_mlir_python_stubgen_enabled)
list(APPEND _declared_sources MLIRPythonExtension.Core.type_stub_gen)
endif()
@@ -998,7 +1005,7 @@ add_mlir_python_modules(MLIRPythonModules
COMMON_CAPI_LINK_LIBS
MLIRPythonCAPI
)
-if(NOT CMAKE_CROSSCOMPILING)
+if(_mlir_python_stubgen_enabled)
add_dependencies(MLIRPythonModules "${_mlir_typestub_gen_target}")
if(MLIR_INCLUDE_TESTS)
add_dependencies(MLIRPythonModules "${_mlirPythonTestNanobind_typestub_gen_target}")
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 7160b52..3130902 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -901,6 +901,132 @@ func.func @scope_merge_without_terminator() {
// -----
+// Check that we simplify extract_strided_metadata of cast
+// when the source of the cast is compatible with what
+// `extract_strided_metadata` accepts.
+//
+// When we apply the transformation the resulting offset, sizes and strides
+// should come straight from the inputs of the cast.
+// Additionally the folder on extract_strided_metadata should propagate the
+// static information.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_cast
+// CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>)
+//
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//
+// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1
+func.func @extract_strided_metadata_of_cast(
+ %arg : memref<3x?xi32, strided<[4, ?], offset:?>>)
+ -> (memref<i32>, index,
+ index, index,
+ index, index) {
+
+ %cast =
+ memref.cast %arg :
+ memref<3x?xi32, strided<[4, ?], offset: ?>> to
+ memref<?x?xi32, strided<[?, ?], offset: ?>>
+
+ %base, %base_offset, %sizes:2, %strides:2 =
+ memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
+ -> memref<i32>, index,
+ index, index,
+ index, index
+
+ return %base, %base_offset,
+ %sizes#0, %sizes#1,
+ %strides#0, %strides#1 :
+ memref<i32>, index,
+ index, index,
+ index, index
+}
+
+// -----
+
+// Check that we simplify extract_strided_metadata of cast
+// when the source of the cast is compatible with what
+// `extract_strided_metadata` accepts.
+//
+// Same as extract_strided_metadata_of_cast but with constant sizes and strides
+// in the destination type.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
+// CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
+//
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
+// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
+// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//
+// CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
+func.func @extract_strided_metadata_of_cast_w_csts(
+ %arg : memref<?x?xi32, strided<[?, ?], offset:?>>)
+ -> (memref<i32>, index,
+ index, index,
+ index, index) {
+
+ %cast =
+ memref.cast %arg :
+ memref<?x?xi32, strided<[?, ?], offset: ?>> to
+ memref<4x?xi32, strided<[?, 18], offset: 25>>
+
+ %base, %base_offset, %sizes:2, %strides:2 =
+ memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>>
+ -> memref<i32>, index,
+ index, index,
+ index, index
+
+ return %base, %base_offset,
+ %sizes#0, %sizes#1,
+ %strides#0, %strides#1 :
+ memref<i32>, index,
+ index, index,
+ index, index
+}
+
+// -----
+
+// Check that we don't simplify extract_strided_metadata of
+// cast when the source of the cast is unranked.
+// Unranked memrefs cannot feed into extract_strided_metadata operations.
+// Note: Technically we could still fold the sizes and strides.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
+// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
+//
+// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
+//
+// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
+func.func @extract_strided_metadata_of_cast_unranked(
+ %arg : memref<*xi32>)
+ -> (memref<i32>, index,
+ index, index,
+ index, index) {
+
+ %cast =
+ memref.cast %arg :
+ memref<*xi32> to
+ memref<?x?xi32, strided<[?, ?], offset: ?>>
+
+ %base, %base_offset, %sizes:2, %strides:2 =
+ memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
+ -> memref<i32>, index,
+ index, index,
+ index, index
+
+ return %base, %base_offset,
+ %sizes#0, %sizes#1,
+ %strides#0, %strides#1 :
+ memref<i32>, index,
+ index, index,
+ index, index
+}
+
+// -----
+
// CHECK-LABEL: func @reinterpret_noop
// CHECK-SAME: (%[[ARG:.*]]: memref<2x3x4xf32>)
// CHECK-NEXT: return %[[ARG]]
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 1e6b011..18cdfb7 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -1378,133 +1378,6 @@ func.func @extract_strided_metadata_of_get_global_with_offset()
// -----
-// Check that we simplify extract_strided_metadata of cast
-// when the source of the cast is compatible with what
-// `extract_strided_metadata`s accept.
-//
-// When we apply the transformation the resulting offset, sizes and strides
-// should come straight from the inputs of the cast.
-// Additionally the folder on extract_strided_metadata should propagate the
-// static information.
-//
-// CHECK-LABEL: func @extract_strided_metadata_of_cast
-// CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>)
-//
-// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
-// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
-//
-// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1
-func.func @extract_strided_metadata_of_cast(
- %arg : memref<3x?xi32, strided<[4, ?], offset:?>>)
- -> (memref<i32>, index,
- index, index,
- index, index) {
-
- %cast =
- memref.cast %arg :
- memref<3x?xi32, strided<[4, ?], offset: ?>> to
- memref<?x?xi32, strided<[?, ?], offset: ?>>
-
- %base, %base_offset, %sizes:2, %strides:2 =
- memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
- -> memref<i32>, index,
- index, index,
- index, index
-
- return %base, %base_offset,
- %sizes#0, %sizes#1,
- %strides#0, %strides#1 :
- memref<i32>, index,
- index, index,
- index, index
-}
-
-// -----
-
-// Check that we simplify extract_strided_metadata of cast
-// when the source of the cast is compatible with what
-// `extract_strided_metadata`s accept.
-//
-// Same as extract_strided_metadata_of_cast but with constant sizes and strides
-// in the destination type.
-//
-// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
-// CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
-//
-// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
-// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
-// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
-//
-// CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
-func.func @extract_strided_metadata_of_cast_w_csts(
- %arg : memref<?x?xi32, strided<[?, ?], offset:?>>)
- -> (memref<i32>, index,
- index, index,
- index, index) {
-
- %cast =
- memref.cast %arg :
- memref<?x?xi32, strided<[?, ?], offset: ?>> to
- memref<4x?xi32, strided<[?, 18], offset: 25>>
-
- %base, %base_offset, %sizes:2, %strides:2 =
- memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>>
- -> memref<i32>, index,
- index, index,
- index, index
-
- return %base, %base_offset,
- %sizes#0, %sizes#1,
- %strides#0, %strides#1 :
- memref<i32>, index,
- index, index,
- index, index
-}
-
-// -----
-
-// Check that we don't simplify extract_strided_metadata of
-// cast when the source of the cast is unranked.
-// Unranked memrefs cannot feed into extract_strided_metadata operations.
-// Note: Technically we could still fold the sizes and strides.
-//
-// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
-// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
-//
-// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
-// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
-//
-// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
-func.func @extract_strided_metadata_of_cast_unranked(
- %arg : memref<*xi32>)
- -> (memref<i32>, index,
- index, index,
- index, index) {
-
- %cast =
- memref.cast %arg :
- memref<*xi32> to
- memref<?x?xi32, strided<[?, ?], offset: ?>>
-
- %base, %base_offset, %sizes:2, %strides:2 =
- memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
- -> memref<i32>, index,
- index, index,
- index, index
-
- return %base, %base_offset,
- %sizes#0, %sizes#1,
- %strides#0, %strides#1 :
- memref<i32>, index,
- index, index,
- index, index
-}
-
-
-// -----
-
memref.global "private" @dynamicShmem : memref<0xf16,3>
// CHECK-LABEL: func @zero_sized_memred
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 8713689..77d18da 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -2200,3 +2200,46 @@ acc.private.recipe @privatization_memref_slice : memref<10x10xf32> init {
acc.yield %result : memref<10x10xf32>
}
+
+// -----
+
+func.func @test_firstprivate_map(%arg0: memref<10xf32>) {
+  // Map the function argument using firstprivate_map to enable moving it to
+  // the accelerator while preventing any present counter updates.
+ %mapped = acc.firstprivate_map varPtr(%arg0 : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
+
+ acc.parallel {
+ // Allocate a local variable inside the parallel region to represent
+ // materialized privatization.
+ %local = memref.alloca() : memref<10xf32>
+
+ // Initialize the local variable with the mapped firstprivate value
+ %c0 = arith.constant 0 : index
+ %c10 = arith.constant 10 : index
+ %c1 = arith.constant 1 : index
+
+ scf.for %i = %c0 to %c10 step %c1 {
+ %val = memref.load %mapped[%i] : memref<10xf32>
+ memref.store %val, %local[%i] : memref<10xf32>
+ }
+
+ acc.yield
+ }
+
+ return
+}
+
+// CHECK-LABEL: func @test_firstprivate_map
+// CHECK-NEXT: %[[MAPPED:.*]] = acc.firstprivate_map varPtr(%{{.*}} : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
+// CHECK-NEXT: acc.parallel {
+// CHECK-NEXT: %[[LOCAL:.*]] = memref.alloca() : memref<10xf32>
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: scf.for %{{.*}} = %[[C0]] to %[[C10]] step %[[C1]] {
+// CHECK-NEXT: %{{.*}} = memref.load %[[MAPPED]][%{{.*}}] : memref<10xf32>
+// CHECK-NEXT: memref.store %{{.*}}, %[[LOCAL]][%{{.*}}] : memref<10xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.yield
+// CHECK-NEXT: }
+// CHECK-NEXT: return
diff --git a/mlir/test/Dialect/OpenACC/support-analysis-varname.mlir b/mlir/test/Dialect/OpenACC/support-analysis-varname.mlir
new file mode 100644
index 0000000..af52bef
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/support-analysis-varname.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s -split-input-file -test-acc-support | FileCheck %s
+
+// Test with direct variable names
+func.func @test_direct_var_name() {
+ // Create a memref with acc.var_name attribute
+ %0 = memref.alloca() {acc.var_name = #acc.var_name<"my_variable">} : memref<10xi32>
+
+ %1 = memref.cast %0 {test.var_name} : memref<10xi32> to memref<10xi32>
+
+ // CHECK: op=%{{.*}} = memref.cast %{{.*}} {test.var_name} : memref<10xi32> to memref<10xi32>
+ // CHECK-NEXT: getVariableName="my_variable"
+
+ return
+}
+
+// -----
+
+// Test through memref.cast
+func.func @test_through_cast() {
+ // Create a 5x2 memref with acc.var_name attribute
+ %0 = memref.alloca() {acc.var_name = #acc.var_name<"casted_variable">} : memref<5x2xi32>
+
+ // Cast to dynamic dimensions
+ %1 = memref.cast %0 : memref<5x2xi32> to memref<?x?xi32>
+
+ // Mark with test attribute - should find name through cast
+ %2 = memref.cast %1 {test.var_name} : memref<?x?xi32> to memref<5x2xi32>
+
+ // CHECK: op=%{{.*}} = memref.cast %{{.*}} {test.var_name} : memref<?x?xi32> to memref<5x2xi32>
+ // CHECK-NEXT: getVariableName="casted_variable"
+
+ return
+}
+
+// -----
+
+// Test with no variable name
+func.func @test_no_var_name() {
+ // Create a memref without acc.var_name attribute
+ %0 = memref.alloca() : memref<10xi32>
+
+ // Mark with test attribute - should find empty string
+ %1 = memref.cast %0 {test.var_name} : memref<10xi32> to memref<10xi32>
+
+ // CHECK: op=%{{.*}} = memref.cast %{{.*}} {test.var_name} : memref<10xi32> to memref<10xi32>
+ // CHECK-NEXT: getVariableName=""
+
+ return
+}
+
+// -----
+
+// Test through multiple casts
+func.func @test_multiple_casts() {
+ // Create a memref with acc.var_name attribute
+ %0 = memref.alloca() {acc.var_name = #acc.var_name<"multi_cast">} : memref<10xi32>
+
+ // Multiple casts
+ %1 = memref.cast %0 : memref<10xi32> to memref<?xi32>
+ %2 = memref.cast %1 : memref<?xi32> to memref<10xi32>
+
+ // Mark with test attribute - should find name through multiple casts
+ %3 = memref.cast %2 {test.var_name} : memref<10xi32> to memref<10xi32>
+
+ // CHECK: op=%{{.*}} = memref.cast %{{.*}} {test.var_name} : memref<10xi32> to memref<10xi32>
+ // CHECK-NEXT: getVariableName="multi_cast"
+
+ return
+}
+
+// -----
+
+// Test with acc.copyin operation
+func.func @test_copyin_name() {
+ // Create a memref
+ %0 = memref.alloca() : memref<10xf32>
+
+ // Create an acc.copyin operation with a name
+ %1 = acc.copyin varPtr(%0 : memref<10xf32>) -> memref<10xf32> {name = "input_data"}
+
+ // Mark with test attribute - should find name from copyin operation
+ %2 = memref.cast %1 {test.var_name} : memref<10xf32> to memref<?xf32>
+
+ // CHECK: op=%{{.*}} = memref.cast %{{.*}} {test.var_name} : memref<10xf32> to memref<?xf32>
+ // CHECK-NEXT: getVariableName="input_data"
+
+ return
+}
diff --git a/mlir/test/Dialect/OpenMP/omp-offload-privatization-prepare-by-value.mlir b/mlir/test/Dialect/OpenMP/omp-offload-privatization-prepare-by-value.mlir
new file mode 100644
index 0000000..8972a08
--- /dev/null
+++ b/mlir/test/Dialect/OpenMP/omp-offload-privatization-prepare-by-value.mlir
@@ -0,0 +1,157 @@
+// RUN: mlir-opt --mlir-disable-threading -omp-offload-privatization-prepare --split-input-file %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr<270> = dense<32> : vector<4xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, i64 = dense<64> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, f80 = dense<128> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, i1 = dense<8> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, f128 = dense<128> : vector<2xi64>, "dlti.endianness" = "little", "dlti.mangling_mode" = "e", "dlti.legal_int_widths" = array<i32: 8, 16, 32, 64>, "dlti.stack_alignment" = 128 : i64>} {
+ llvm.func @free(!llvm.ptr)
+ llvm.func @malloc(i64) -> !llvm.ptr
+
+ omp.private {type = firstprivate} @private_eye : i32 copy {
+ ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+ %0 = llvm.load %arg0 : !llvm.ptr -> i32
+ llvm.store %0, %arg1 : i32, !llvm.ptr
+ omp.yield(%arg1 : !llvm.ptr)
+ }
+ omp.private {type = firstprivate} @boxchar_firstprivate : !llvm.struct<(ptr, i64)> init {
+ ^bb0(%arg0: !llvm.struct<(ptr, i64)>, %arg1: !llvm.struct<(ptr, i64)>):
+ %0 = llvm.extractvalue %arg0[0] : !llvm.struct<(ptr, i64)>
+ %1 = llvm.extractvalue %arg0[1] : !llvm.struct<(ptr, i64)>
+ %8 = llvm.call @malloc(%1) {bindc_name = "", uniq_name = ""} : (i64) -> !llvm.ptr
+ %9 = llvm.mlir.undef : !llvm.struct<(ptr, i64)>
+ %10 = llvm.insertvalue %8, %9[0] : !llvm.struct<(ptr, i64)>
+ %11 = llvm.insertvalue %1, %10[1] : !llvm.struct<(ptr, i64)>
+ omp.yield(%11 : !llvm.struct<(ptr, i64)>)
+ } copy {
+ ^bb0(%arg0: !llvm.struct<(ptr, i64)>, %arg1: !llvm.struct<(ptr, i64)>):
+ %3 = llvm.extractvalue %arg0[0] : !llvm.struct<(ptr, i64)>
+ %4 = llvm.extractvalue %arg0[1] : !llvm.struct<(ptr, i64)>
+ %5 = llvm.extractvalue %arg1[0] : !llvm.struct<(ptr, i64)>
+ %6 = llvm.extractvalue %arg1[1] : !llvm.struct<(ptr, i64)>
+ %7 = llvm.icmp "slt" %6, %4 : i64
+ %8 = llvm.select %7, %6, %4 : i1, i64
+ "llvm.intr.memmove"(%5, %3, %8) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> ()
+ omp.yield(%arg1 : !llvm.struct<(ptr, i64)>)
+ } dealloc {
+ ^bb0(%arg0: !llvm.struct<(ptr, i64)>):
+ %0 = llvm.extractvalue %arg0[0] : !llvm.struct<(ptr, i64)>
+ %1 = llvm.extractvalue %arg0[1] : !llvm.struct<(ptr, i64)>
+ llvm.call @free(%0) : (!llvm.ptr) -> ()
+ omp.yield
+ }
+
+ llvm.func @target_boxchar_(%arg0: !llvm.ptr {fir.bindc_name = "l"}) attributes {fir.internal_name = "_QPtarget_boxchar", frame_pointer = #llvm.framePointerKind<all>, target_cpu = "x86-64"} {
+ %0 = llvm.mlir.constant(1 : i64) : i64
+ %1 = llvm.alloca %0 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
+ %2 = llvm.mlir.constant(1 : i64) : i64
+ %3 = llvm.alloca %2 x !llvm.struct<(ptr, i64)> : (i64) -> !llvm.ptr
+ %4 = llvm.mlir.constant(1 : index) : i64
+ %5 = llvm.mlir.constant(0 : index) : i64
+ %6 = llvm.mlir.constant(0 : i32) : i32
+ %7 = llvm.mlir.constant(1 : i64) : i64
+ %8 = llvm.mlir.constant(1 : i64) : i64
+ %9 = llvm.load %arg0 : !llvm.ptr -> i32
+ %10 = llvm.icmp "sgt" %9, %6 : i32
+ %11 = llvm.select %10, %9, %6 : i1, i32
+ %12 = llvm.mlir.constant(1 : i64) : i64
+ %13 = llvm.sext %11 : i32 to i64
+ %14 = llvm.alloca %13 x i8 {bindc_name = "char_var"} : (i64) -> !llvm.ptr
+ %15 = llvm.mlir.undef : !llvm.struct<(ptr, i64)>
+ %16 = llvm.sext %11 : i32 to i64
+ %17 = llvm.insertvalue %14, %15[0] : !llvm.struct<(ptr, i64)>
+ %18 = llvm.insertvalue %16, %17[1] : !llvm.struct<(ptr, i64)>
+ llvm.store %18, %3 : !llvm.struct<(ptr, i64)>, !llvm.ptr
+ %19 = llvm.load %3 : !llvm.ptr -> !llvm.struct<(ptr, i64)>
+ %20 = llvm.extractvalue %19[0] : !llvm.struct<(ptr, i64)>
+ %21 = llvm.extractvalue %19[1] : !llvm.struct<(ptr, i64)>
+ %22 = llvm.sub %21, %4 : i64
+ %23 = omp.map.bounds lower_bound(%5 : i64) upper_bound(%22 : i64) extent(%21 : i64) stride(%4 : i64) start_idx(%5 : i64) {stride_in_bytes = true}
+ %24 = llvm.getelementptr %3[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64)>
+ %25 = omp.map.info var_ptr(%3 : !llvm.ptr, i8) map_clauses(implicit, to) capture(ByRef) var_ptr_ptr(%24 : !llvm.ptr) bounds(%23) -> !llvm.ptr
+ %26 = omp.map.info var_ptr(%3 : !llvm.ptr, !llvm.struct<(ptr, i64)>) map_clauses(to) capture(ByRef) members(%25 : [0] : !llvm.ptr) -> !llvm.ptr
+ %27 = omp.map.info var_ptr(%1 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr
+ omp.target nowait map_entries(%26 -> %arg1, %27 -> %arg2, %25 -> %arg3 : !llvm.ptr, !llvm.ptr, !llvm.ptr) private(@boxchar_firstprivate %18 -> %arg4 [map_idx=0], @private_eye %1 -> %arg5 [map_idx=1] : !llvm.struct<(ptr, i64)>, !llvm.ptr) {
+ omp.terminator
+ }
+ llvm.return
+ }
+}
+// CHECK-LABEL: llvm.func @target_boxchar_(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr {fir.bindc_name = "l"}) attributes {fir.internal_name = "_QPtarget_boxchar", frame_pointer = #llvm.framePointerKind<all>, target_cpu = "x86-64"} {
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[VAL_1:.*]] = llvm.alloca %[[VAL_0]] x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
+// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(16 : i64) : i64
+// CHECK: %[[HEAP0:.*]] = llvm.call @malloc(%[[VAL_3]]) : (i64) -> !llvm.ptr
+// CHECK: %[[VAL_5:.*]] = llvm.alloca %[[VAL_2]] x !llvm.struct<(ptr, i64)> : (i64) -> !llvm.ptr
+// CHECK: %[[VAL_6:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[VAL_8:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[VAL_9:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[VAL_11:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> i32
+// CHECK: %[[VAL_12:.*]] = llvm.icmp "sgt" %[[VAL_11]], %[[VAL_8]] : i32
+// CHECK: %[[VAL_13:.*]] = llvm.select %[[VAL_12]], %[[VAL_11]], %[[VAL_8]] : i1, i32
+// CHECK: %[[VAL_14:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[VAL_15:.*]] = llvm.sext %[[VAL_13]] : i32 to i64
+// CHECK: %[[VAL_16:.*]] = llvm.alloca %[[VAL_15]] x i8 {bindc_name = "char_var"} : (i64) -> !llvm.ptr
+// CHECK: %[[VAL_17:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, i64)>
+// CHECK: %[[VAL_18:.*]] = llvm.sext %[[VAL_13]] : i32 to i64
+// CHECK: %[[VAL_19:.*]] = llvm.insertvalue %[[VAL_16]], %[[VAL_17]][0] : !llvm.struct<(ptr, i64)>
+// CHECK: %[[VAL_20:.*]] = llvm.insertvalue %[[VAL_18]], %[[VAL_19]][1] : !llvm.struct<(ptr, i64)>
+// CHECK: llvm.store %[[VAL_20]], %[[VAL_5]] : !llvm.struct<(ptr, i64)>, !llvm.ptr
+// CHECK: %[[VAL_21:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr -> !llvm.struct<(ptr, i64)>
+// CHECK: %[[VAL_22:.*]] = llvm.extractvalue %[[VAL_21]][0] : !llvm.struct<(ptr, i64)>
+// CHECK: %[[VAL_23:.*]] = llvm.extractvalue %[[VAL_21]][1] : !llvm.struct<(ptr, i64)>
+// CHECK: %[[VAL_24:.*]] = llvm.sub %[[VAL_23]], %[[VAL_6]] : i64
+// CHECK: %[[VAL_25:.*]] = omp.map.bounds lower_bound(%[[VAL_7]] : i64) upper_bound(%[[VAL_24]] : i64) extent(%[[VAL_23]] : i64) stride(%[[VAL_6]] : i64) start_idx(%[[VAL_7]] : i64) {stride_in_bytes = true}
+// CHECK: %[[VAL_26:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr -> !llvm.struct<(ptr, i64)>
+// CHECK: %[[VAL_27:.*]] = llvm.load %[[HEAP0]] : !llvm.ptr -> !llvm.struct<(ptr, i64)>
+// CHECK: %[[VAL_28:.*]] = llvm.call @boxchar_firstprivate_init(%[[VAL_26]], %[[VAL_27]]) : (!llvm.struct<(ptr, i64)>, !llvm.struct<(ptr, i64)>) -> !llvm.struct<(ptr, i64)>
+// CHECK: %[[VAL_29:.*]] = llvm.call @boxchar_firstprivate_copy(%[[VAL_26]], %[[VAL_28]]) : (!llvm.struct<(ptr, i64)>, !llvm.struct<(ptr, i64)>) -> !llvm.struct<(ptr, i64)>
+// CHECK: llvm.store %[[VAL_29]], %[[HEAP0]] : !llvm.struct<(ptr, i64)>, !llvm.ptr
+// CHECK: %[[VAL_30:.*]] = omp.map.info var_ptr(%[[VAL_1]] : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr
+// CHECK: %[[VAL_31:.*]] = llvm.getelementptr %[[HEAP0]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64)>
+// CHECK: %[[VAL_32:.*]] = omp.map.info var_ptr(%[[HEAP0]] : !llvm.ptr, i8) map_clauses(implicit, to) capture(ByRef) var_ptr_ptr(%[[VAL_31]] : !llvm.ptr) bounds(%[[VAL_25]]) -> !llvm.ptr
+// CHECK: %[[VAL_33:.*]] = omp.map.info var_ptr(%[[HEAP0]] : !llvm.ptr, !llvm.struct<(ptr, i64)>) map_clauses(to) capture(ByRef) members(%[[VAL_32]] : [0] : !llvm.ptr) -> !llvm.ptr
+// CHECK: %[[VAL_34:.*]] = llvm.load %[[HEAP0]] : !llvm.ptr -> !llvm.struct<(ptr, i64)>
+// CHECK: omp.target depend(taskdependout -> %[[HEAP0]] : !llvm.ptr) nowait map_entries(%[[VAL_33]] -> %[[VAL_35:.*]], %[[VAL_30]] -> %[[VAL_36:.*]], %[[VAL_32]] -> %[[VAL_37:.*]] : !llvm.ptr, !llvm.ptr, !llvm.ptr) private(@boxchar_firstprivate %[[VAL_34]] -> %[[VAL_38:.*]] [map_idx=0], @private_eye %[[VAL_1]] -> %[[VAL_39:.*]] [map_idx=1] : !llvm.struct<(ptr, i64)>, !llvm.ptr) {
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.task depend(taskdependin -> %[[HEAP0]] : !llvm.ptr) {
+// CHECK: llvm.call @boxchar_firstprivate_dealloc(%[[VAL_29]]) : (!llvm.struct<(ptr, i64)>) -> ()
+// CHECK: llvm.call @free(%[[HEAP0]]) : (!llvm.ptr) -> ()
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: llvm.return
+// CHECK: }
+
+// CHECK-LABEL: llvm.func @boxchar_firstprivate_init(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.struct<(ptr, i64)>,
+// CHECK-SAME: %[[ARG1:.*]]: !llvm.struct<(ptr, i64)>) -> !llvm.struct<(ptr, i64)> attributes {always_inline} {
+// CHECK: %[[VAL_0:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.struct<(ptr, i64)>
+// CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(ptr, i64)>
+// CHECK: %[[VAL_2:.*]] = llvm.call @malloc(%[[VAL_1]]) {bindc_name = "", uniq_name = ""} : (i64) -> !llvm.ptr
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, i64)>
+// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_2]], %[[VAL_3]][0] : !llvm.struct<(ptr, i64)>
+// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_4]][1] : !llvm.struct<(ptr, i64)>
+// CHECK: llvm.return %[[VAL_5]] : !llvm.struct<(ptr, i64)>
+// CHECK: }
+
+// CHECK-LABEL: llvm.func @boxchar_firstprivate_copy(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.struct<(ptr, i64)>,
+// CHECK-SAME: %[[ARG1:.*]]: !llvm.struct<(ptr, i64)>) -> !llvm.struct<(ptr, i64)> attributes {always_inline} {
+// CHECK: %[[VAL_0:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.struct<(ptr, i64)>
+// CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(ptr, i64)>
+// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(ptr, i64)>
+// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(ptr, i64)>
+// CHECK: %[[VAL_4:.*]] = llvm.icmp "slt" %[[VAL_3]], %[[VAL_1]] : i64
+// CHECK: %[[VAL_5:.*]] = llvm.select %[[VAL_4]], %[[VAL_3]], %[[VAL_1]] : i1, i64
+// CHECK: "llvm.intr.memmove"(%[[VAL_2]], %[[VAL_0]], %[[VAL_5]]) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> ()
+// CHECK: llvm.return %[[ARG1]] : !llvm.struct<(ptr, i64)>
+// CHECK: }
+
+// CHECK-LABEL: llvm.func @boxchar_firstprivate_dealloc(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.struct<(ptr, i64)>) attributes {always_inline} {
+// CHECK: %[[VAL_0:.*]] = llvm.extractvalue %[[ARG0]][0] : !llvm.struct<(ptr, i64)>
+// CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(ptr, i64)>
+// CHECK: llvm.call @free(%[[VAL_0]]) : (!llvm.ptr) -> ()
+// CHECK: llvm.return
+// CHECK: }
diff --git a/mlir/test/Dialect/OpenMP/omp-offload-privatization-prepare.mlir b/mlir/test/Dialect/OpenMP/omp-offload-privatization-prepare.mlir
new file mode 100644
index 0000000..0377d49
--- /dev/null
+++ b/mlir/test/Dialect/OpenMP/omp-offload-privatization-prepare.mlir
@@ -0,0 +1,201 @@
+// RUN: mlir-opt --mlir-disable-threading -omp-offload-privatization-prepare --split-input-file %s | FileCheck %s
+
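+// This test verifies that the pass relocates firstprivate variables of
+// deferred (nowait) target regions from the stack to the heap, invokes the
+// privatizer's init and copy regions (outlined into functions) on the host,
+// and defers deallocation to a cleanup omp.task that depends on the target
+// task.
+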
+module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr<270> = dense<32> : vector<4xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, i64 = dense<64> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, f80 = dense<128> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, i1 = dense<8> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, f128 = dense<128> : vector<2xi64>, "dlti.endianness" = "little", "dlti.mangling_mode" = "e", "dlti.legal_int_widths" = array<i32: 8, 16, 32, 64>, "dlti.stack_alignment" = 128 : i64>} {
+ llvm.func @free(!llvm.ptr)
+ llvm.func @malloc(i64) -> !llvm.ptr
+
+ omp.private {type = firstprivate} @firstprivatizer : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> init {
+ ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+ %0 = llvm.mlir.constant(48 : i64) : i64
+ %1 = llvm.call @malloc(%0) : (i64) -> !llvm.ptr
+ %2 = llvm.getelementptr %arg1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+ llvm.store %1, %2 : !llvm.ptr, !llvm.ptr
+ omp.yield(%arg1 : !llvm.ptr)
+ } copy {
+ ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+ %0 = llvm.mlir.constant(48 : i32) : i32
+ "llvm.intr.memcpy"(%arg1, %arg0, %0) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ omp.yield(%arg1 : !llvm.ptr)
+ } dealloc {
+ ^bb0(%arg0: !llvm.ptr):
+ llvm.call @free(%arg0) : (!llvm.ptr) -> ()
+ omp.yield
+ }
+ omp.private {type = firstprivate} @firstprivatizer_1 : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> init {
+ ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+ %0 = llvm.mlir.constant(48 : i64) : i64
+ %1 = llvm.call @malloc(%0) : (i64) -> !llvm.ptr
+ %2 = llvm.getelementptr %arg1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+ llvm.store %1, %2 : !llvm.ptr, !llvm.ptr
+ omp.yield(%arg1 : !llvm.ptr)
+ } copy {
+ ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+ %0 = llvm.mlir.constant(48 : i32) : i32
+ "llvm.intr.memcpy"(%arg1, %arg0, %0) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+ omp.yield(%arg1 : !llvm.ptr)
+ } dealloc {
+ ^bb0(%arg0: !llvm.ptr):
+ llvm.call @free(%arg0) : (!llvm.ptr) -> ()
+ omp.yield
+ }
+
+ llvm.func internal @firstprivate_test(%arg0: !llvm.ptr {fir.bindc_name = "ptr0"}, %arg1: !llvm.ptr {fir.bindc_name = "ptr1"}) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ %1 = llvm.mlir.constant(0 : index) : i64
+ %5 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %19 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {bindc_name = "local"} : (i32) -> !llvm.ptr
+ %20 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {bindc_name = "glocal"} : (i32) -> !llvm.ptr
+ %21 = llvm.alloca %0 x i32 {bindc_name = "i"} : (i32) -> !llvm.ptr
+ %33 = llvm.mlir.undef : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+ llvm.store %33, %19 : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, !llvm.ptr
+ llvm.store %33, %20 : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, !llvm.ptr
+ llvm.store %0, %21 : i32, !llvm.ptr
+ %124 = omp.map.info var_ptr(%21 : !llvm.ptr, i32) map_clauses(implicit) capture(ByCopy) -> !llvm.ptr {name = "i"}
+ %150 = llvm.getelementptr %19[0, 7, %1, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+ %151 = llvm.load %150 : !llvm.ptr -> i64
+ %152 = llvm.getelementptr %19[0, 7, %1, 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+ %153 = llvm.load %152 : !llvm.ptr -> i64
+ %154 = llvm.getelementptr %19[0, 7, %1, 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+ %155 = llvm.load %154 : !llvm.ptr -> i64
+ %156 = llvm.sub %153, %1 : i64
+ %157 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%156 : i64) extent(%153 : i64) stride(%155 : i64) start_idx(%151 : i64) {stride_in_bytes = true}
+ %158 = llvm.getelementptr %19[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+ %159 = omp.map.info var_ptr(%19 : !llvm.ptr, i32) map_clauses(descriptor_base_addr, to) capture(ByRef) var_ptr_ptr(%158 : !llvm.ptr) bounds(%157) -> !llvm.ptr {name = ""}
+ %160 = omp.map.info var_ptr(%19 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>) map_clauses(always, descriptor, to) capture(ByRef) members(%159 : [0] : !llvm.ptr) -> !llvm.ptr
+ %1501 = llvm.getelementptr %20[0, 7, %1, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+ %1511 = llvm.load %1501 : !llvm.ptr -> i64
+ %1521 = llvm.getelementptr %20[0, 7, %1, 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+ %1531 = llvm.load %1521 : !llvm.ptr -> i64
+ %1541 = llvm.getelementptr %20[0, 7, %1, 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+ %1551 = llvm.load %1541 : !llvm.ptr -> i64
+ %1561 = llvm.sub %1531, %1 : i64
+ %1571 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%1561 : i64) extent(%1531 : i64) stride(%1551 : i64) start_idx(%1511 : i64) {stride_in_bytes = true}
+ %1581 = llvm.getelementptr %20[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+ %1591 = omp.map.info var_ptr(%20 : !llvm.ptr, i32) map_clauses(descriptor_base_addr, to) capture(ByRef) var_ptr_ptr(%1581 : !llvm.ptr) bounds(%1571) -> !llvm.ptr {name = ""}
+ %1601 = omp.map.info var_ptr(%20 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>) map_clauses(always, descriptor, to) capture(ByRef) members(%1591 : [0] : !llvm.ptr) -> !llvm.ptr
+
+  // Use two firstprivate variables to check that, even when multiple
+  // variables need cleanup, only one cleanup omp.task is generated.
+ omp.target nowait map_entries(%124 -> %arg2, %160 -> %arg5, %159 -> %arg8, %1601 -> %arg9, %1591 -> %arg10 : !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) private(@firstprivatizer %19 -> %arg11 [map_idx=1], @firstprivatizer_1 %20 -> %arg12 [map_idx=3] : !llvm.ptr, !llvm.ptr) {
+ omp.terminator
+ }
+ %166 = llvm.mlir.constant(48 : i32) : i32
+ %167 = llvm.getelementptr %19[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+ %168 = llvm.load %167 : !llvm.ptr -> !llvm.ptr
+ llvm.call @free(%168) : (!llvm.ptr) -> ()
+ llvm.return
+ }
+
+}
+// CHECK-LABEL: llvm.func @free(!llvm.ptr)
+// CHECK: llvm.func @malloc(i64) -> !llvm.ptr
+
+
+// CHECK-LABEL: llvm.func internal @firstprivate_test(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr {fir.bindc_name = "ptr0"},
+// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr {fir.bindc_name = "ptr1"}) {
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[VAL_2:.*]] = llvm.alloca %[[VAL_0]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(48 : i64) : i64
+// CHECK: %[[HEAP0:.*]] = llvm.call @malloc(%[[VAL_3]]) : (i64) -> !llvm.ptr
+// CHECK: %[[VAL_5:.*]] = llvm.alloca %[[VAL_0]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {bindc_name = "local"} : (i32) -> !llvm.ptr
+// CHECK: %[[VAL_6:.*]] = llvm.mlir.constant(48 : i64) : i64
+// CHECK: %[[HEAP1:.*]] = llvm.call @malloc(%[[VAL_6]]) : (i64) -> !llvm.ptr
+// CHECK: %[[VAL_8:.*]] = llvm.alloca %[[VAL_0]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> {bindc_name = "glocal"} : (i32) -> !llvm.ptr
+// CHECK: %[[VAL_9:.*]] = llvm.alloca %[[VAL_0]] x i32 {bindc_name = "i"} : (i32) -> !llvm.ptr
+// CHECK: %[[VAL_10:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+// CHECK: llvm.store %[[VAL_10]], %[[VAL_5]] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, !llvm.ptr
+// CHECK: llvm.store %[[VAL_10]], %[[VAL_8]] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, !llvm.ptr
+// CHECK: llvm.store %[[VAL_0]], %[[VAL_9]] : i32, !llvm.ptr
+// CHECK: %[[VAL_11:.*]] = omp.map.info var_ptr(%[[VAL_9]] : !llvm.ptr, i32) map_clauses(implicit) capture(ByCopy) -> !llvm.ptr {name = "i"}
+// CHECK: %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_5]][0, 7, %[[VAL_1]], 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+// CHECK: %[[VAL_13:.*]] = llvm.load %[[VAL_12]] : !llvm.ptr -> i64
+// CHECK: %[[VAL_14:.*]] = llvm.getelementptr %[[VAL_5]][0, 7, %[[VAL_1]], 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+// CHECK: %[[VAL_15:.*]] = llvm.load %[[VAL_14]] : !llvm.ptr -> i64
+// CHECK: %[[VAL_16:.*]] = llvm.getelementptr %[[VAL_5]][0, 7, %[[VAL_1]], 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+// CHECK: %[[VAL_17:.*]] = llvm.load %[[VAL_16]] : !llvm.ptr -> i64
+// CHECK: %[[VAL_18:.*]] = llvm.sub %[[VAL_15]], %[[VAL_1]] : i64
+// CHECK: %[[VAL_19:.*]] = omp.map.bounds lower_bound(%[[VAL_1]] : i64) upper_bound(%[[VAL_18]] : i64) extent(%[[VAL_15]] : i64) stride(%[[VAL_17]] : i64) start_idx(%[[VAL_13]] : i64) {stride_in_bytes = true}
+// CHECK: %[[VAL_20:.*]] = llvm.call @firstprivatizer_init(%[[VAL_5]], %[[HEAP0]]) : (!llvm.ptr, !llvm.ptr) -> !llvm.ptr
+// CHECK: %[[VAL_21:.*]] = llvm.call @firstprivatizer_copy(%[[VAL_5]], %[[VAL_20]]) : (!llvm.ptr, !llvm.ptr) -> !llvm.ptr
+// CHECK: %[[VAL_22:.*]] = llvm.getelementptr %[[VAL_8]][0, 7, %[[VAL_1]], 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+// CHECK: %[[VAL_23:.*]] = llvm.load %[[VAL_22]] : !llvm.ptr -> i64
+// CHECK: %[[VAL_24:.*]] = llvm.getelementptr %[[VAL_8]][0, 7, %[[VAL_1]], 1] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+// CHECK: %[[VAL_25:.*]] = llvm.load %[[VAL_24]] : !llvm.ptr -> i64
+// CHECK: %[[VAL_26:.*]] = llvm.getelementptr %[[VAL_8]][0, 7, %[[VAL_1]], 2] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+// CHECK: %[[VAL_27:.*]] = llvm.load %[[VAL_26]] : !llvm.ptr -> i64
+// CHECK: %[[VAL_28:.*]] = llvm.sub %[[VAL_25]], %[[VAL_1]] : i64
+// CHECK: %[[VAL_29:.*]] = omp.map.bounds lower_bound(%[[VAL_1]] : i64) upper_bound(%[[VAL_28]] : i64) extent(%[[VAL_25]] : i64) stride(%[[VAL_27]] : i64) start_idx(%[[VAL_23]] : i64) {stride_in_bytes = true}
+// CHECK: %[[VAL_30:.*]] = llvm.call @firstprivatizer_1_init(%[[VAL_8]], %[[HEAP1]]) : (!llvm.ptr, !llvm.ptr) -> !llvm.ptr
+// CHECK: %[[VAL_31:.*]] = llvm.call @firstprivatizer_1_copy(%[[VAL_8]], %[[VAL_30]]) : (!llvm.ptr, !llvm.ptr) -> !llvm.ptr
+// CHECK: %[[VAL_32:.*]] = llvm.getelementptr %[[HEAP0]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+// CHECK: %[[VAL_33:.*]] = omp.map.info var_ptr(%[[HEAP0]] : !llvm.ptr, i32) map_clauses({{.*}}to{{.*}}) capture(ByRef) var_ptr_ptr(%[[VAL_32]] : !llvm.ptr) bounds(%[[VAL_19]]) -> !llvm.ptr {name = ""}
+// CHECK: %[[VAL_34:.*]] = omp.map.info var_ptr(%[[HEAP0]] : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>) map_clauses(always,{{.*}}to) capture(ByRef) members(%[[VAL_33]] : [0] : !llvm.ptr) -> !llvm.ptr
+// CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[HEAP1]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+// CHECK: %[[VAL_36:.*]] = omp.map.info var_ptr(%[[HEAP1]] : !llvm.ptr, i32) map_clauses({{.*}}to{{.*}}) capture(ByRef) var_ptr_ptr(%[[VAL_35]] : !llvm.ptr) bounds(%[[VAL_29]]) -> !llvm.ptr {name = ""}
+// CHECK: %[[VAL_37:.*]] = omp.map.info var_ptr(%[[HEAP1]] : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>) map_clauses(always,{{.*}}to) capture(ByRef) members(%[[VAL_36]] : [0] : !llvm.ptr) -> !llvm.ptr
+// CHECK: omp.target depend(taskdependout -> %[[HEAP0]] : !llvm.ptr) nowait map_entries(%[[VAL_11]] -> %[[VAL_38:.*]], %[[VAL_34]] -> %[[VAL_39:.*]], %[[VAL_33]] -> %[[VAL_40:.*]], %[[VAL_37]] -> %[[VAL_41:.*]], %[[VAL_36]] -> %[[VAL_42:.*]] : !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) private(@firstprivatizer %[[HEAP0]] -> %[[VAL_43:.*]] [map_idx=1], @firstprivatizer_1 %[[HEAP1]] -> %[[VAL_44:.*]] [map_idx=3] : !llvm.ptr, !llvm.ptr) {
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.task depend(taskdependin -> %[[HEAP0]] : !llvm.ptr) {
+// CHECK: llvm.call @firstprivatizer_1_dealloc(%[[VAL_31]]) : (!llvm.ptr) -> ()
+// CHECK: llvm.call @free(%[[HEAP1]]) : (!llvm.ptr) -> ()
+// CHECK: llvm.call @firstprivatizer_dealloc(%[[VAL_21]]) : (!llvm.ptr) -> ()
+// CHECK: llvm.call @free(%[[HEAP0]]) : (!llvm.ptr) -> ()
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: %[[VAL_45:.*]] = llvm.mlir.constant(48 : i32) : i32
+// CHECK: %[[VAL_46:.*]] = llvm.getelementptr %[[VAL_5]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+// CHECK: %[[VAL_47:.*]] = llvm.load %[[VAL_46]] : !llvm.ptr -> !llvm.ptr
+// CHECK: llvm.call @free(%[[VAL_47]]) : (!llvm.ptr) -> ()
+// CHECK: llvm.return
+// CHECK: }
+
+// CHECK-LABEL: llvm.func @firstprivatizer_init(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr,
+// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr) -> !llvm.ptr attributes {always_inline} {
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(48 : i64) : i64
+// CHECK: %[[VAL_1:.*]] = llvm.call @malloc(%[[VAL_0]]) : (i64) -> !llvm.ptr
+// CHECK: %[[VAL_2:.*]] = llvm.getelementptr %[[ARG1]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+// CHECK: llvm.store %[[VAL_1]], %[[VAL_2]] : !llvm.ptr, !llvm.ptr
+// CHECK: llvm.return %[[ARG1]] : !llvm.ptr
+// CHECK: }
+
+// CHECK-LABEL: llvm.func @firstprivatizer_copy(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr,
+// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr) -> !llvm.ptr attributes {always_inline} {
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(48 : i32) : i32
+// CHECK: "llvm.intr.memcpy"(%[[ARG1]], %[[ARG0]], %[[VAL_0]]) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+// CHECK: llvm.return %[[ARG1]] : !llvm.ptr
+// CHECK: }
+
+// CHECK-LABEL: llvm.func @firstprivatizer_dealloc(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr) attributes {always_inline} {
+// CHECK: llvm.call @free(%[[ARG0]]) : (!llvm.ptr) -> ()
+// CHECK: llvm.return
+// CHECK: }
+
+// CHECK-LABEL: llvm.func @firstprivatizer_1_init(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr,
+// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr) -> !llvm.ptr attributes {always_inline} {
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(48 : i64) : i64
+// CHECK: %[[VAL_1:.*]] = llvm.call @malloc(%[[VAL_0]]) : (i64) -> !llvm.ptr
+// CHECK: %[[VAL_2:.*]] = llvm.getelementptr %[[ARG1]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
+// CHECK: llvm.store %[[VAL_1]], %[[VAL_2]] : !llvm.ptr, !llvm.ptr
+// CHECK: llvm.return %[[ARG1]] : !llvm.ptr
+// CHECK: }
+
+// CHECK-LABEL: llvm.func @firstprivatizer_1_copy(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr,
+// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr) -> !llvm.ptr attributes {always_inline} {
+// CHECK: %[[VAL_0:.*]] = llvm.mlir.constant(48 : i32) : i32
+// CHECK: "llvm.intr.memcpy"(%[[ARG1]], %[[ARG0]], %[[VAL_0]]) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+// CHECK: llvm.return %[[ARG1]] : !llvm.ptr
+// CHECK: }
+
+// CHECK-LABEL: llvm.func @firstprivatizer_1_dealloc(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr) attributes {always_inline} {
+// CHECK: llvm.call @free(%[[ARG0]]) : (!llvm.ptr) -> ()
+// CHECK: llvm.return
+// CHECK: }
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index e8525a5..7574afa 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -9,6 +9,15 @@ func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<1xi32> {
// -----
+// CHECK-LABEL: @test_argmax_fold_i64_index
+func.func @test_argmax_fold_i64_index(%arg0: tensor<1xi8>) -> tensor<i64> {
+ // CHECK: "tosa.const"() <{values = dense<0> : tensor<i64>}> : () -> tensor<i64>
+ %0 = tosa.argmax %arg0 {axis = 0 : i32} : (tensor<1xi8>) -> tensor<i64>
+ return %0 : tensor<i64>
+}
+
+// -----
+
// CHECK-LABEL: @pad_wh_avg_pool2d_fold
func.func @pad_wh_avg_pool2d_fold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
// CHECK-NOT: tosa.pad
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 0e1365a..27a3dc3 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -214,3 +214,54 @@ gpu.module @xevm_module{
}
}
+
+// -----
+// CHECK-LABEL: gpu.func @warp_scf_for_unused_uniform_for_result(
+// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] args(%{{.*}} : index,
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME: memref<16x16xf32>) -> (vector<16x1xf32>, vector<16x1xf32>) {
+// CHECK: gpu.yield %{{.*}}, {{.*}} : vector<16x16xf32>, vector<16x1xf32>
+// CHECK: }
+// CHECK: %{{.*}}:2 = scf.for {{.*}} to %{{.*}} step %{{.*}} iter_args
+// CHECK-SAME: (%{{.*}} = %[[W]]#0, %{{.*}} = %[[W]]#1) -> (vector<16x1xf32>, vector<16x1xf32>) {
+// CHECK: %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16]
+// CHECK-SAME: args(%{{.*}} : vector<16x1xf32>, vector<16x1xf32>) -> (vector<16x1xf32>, vector<16x1xf32>) {
+// CHECK: gpu.yield %{{.*}}, %{{.*}} : vector<16x16xf32>, vector<16x1xf32>
+// CHECK: }
+// CHECK: scf.yield %[[W1]]#0, %[[W1]]#1 : vector<16x1xf32>, vector<16x1xf32>
+// CHECK: }
+gpu.module @xevm_module{
+ gpu.func @warp_scf_for_unused_uniform_for_result(%arg0: index,
+ %arg1: !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+ %arg2: memref<16x16xf32>) {
+ %c128 = arith.constant 128 : index
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %ini = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : () -> (vector<16x1xf32>)
+ %ini2 = "some_def"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : () -> (vector<16x16xf32>)
+ %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini2, %arg5 = %ini) -> (vector<16x16xf32>, vector<16x1xf32>) {
+ %1 = "some_def"(%arg5)
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : (vector<16x1xf32>) -> (vector<16x1xf32>)
+ %acc = "some_def"(%arg4, %1)
+ {
+ layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : (vector<16x16xf32>, vector<16x1xf32>) -> (vector<16x16xf32>)
+ scf.yield %acc, %1 : vector<16x16xf32>, vector<16x1xf32>
+ }
+ {
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ xegpu.store_nd %3#0, %arg1[%c0, %c0]
+ : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
+ }
+}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
index 9f4393e..127ab70 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
@@ -103,6 +103,17 @@ func.func @main() {
// CHECK: unexpected negative result on dimension #0 of input/output operand #0
func.call @reverse_from_3(%d5x) : (tensor<?xf32>) -> (tensor<?xf32>)
+ %c0x = arith.constant dense<1.0> : tensor<0xf32>
+ %d0x = tensor.cast %c0x : tensor<0xf32> to tensor<?xf32>
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @fill_empty_1d(%d0x) : (tensor<?xf32>) -> (tensor<?xf32>)
+
+ %c0x5 = arith.constant dense<0.0> : tensor<0x5xf32>
+ %d0x5 = tensor.cast %c0x5 : tensor<0x5xf32> to tensor<?x?xf32>
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @fill_empty_2d(%d0x5) : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
return
}
@@ -297,3 +308,15 @@ func.func @reverse_from_3(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
} -> tensor<?xf32>
return %result : tensor<?xf32>
}
+
+func.func @fill_empty_1d(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+ %c0 = arith.constant 0.0 : f32
+ %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?xf32>) -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+func.func @fill_empty_2d(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+ %c0 = arith.constant 0.0 : f32
+ %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
index 04e2ddf..451475c 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
@@ -10,3 +10,14 @@ llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
%res2 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB {relu = true} : i8 (f4E2M1FN)
llvm.return
}
+
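+// Two e2m1 values are packed into a single i8, which the lowering
+// zero-extends to i16 before calling the NVVM intrinsic.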
+// CHECK-LABEL: @convert_f4x2_to_f16x2
+llvm.func @convert_f4x2_to_f16x2(%src : i8) {
+ // CHECK: %[[res1:.*]] = zext i8 %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 %[[res1]])
+  %res1 = nvvm.convert.f4x2.to.f16x2 %src : i8 (f4E2M1FN) -> vector<2xf16>
+  // CHECK: %[[res2:.*]] = zext i8 %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 %[[res2]])
+  %res2 = nvvm.convert.f4x2.to.f16x2 %src {relu = true} : i8 (f4E2M1FN) -> vector<2xf16>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
index 9928992..61a7a48 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
@@ -19,3 +19,27 @@ llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
%res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> (f6E3M2FN)
llvm.return
}
+
+// -----
+
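+// Each f6x2 source is a vector<2xi8> that the lowering bitcasts to i16
+// before calling the NVVM intrinsic.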
+// CHECK-LABEL: @convert_f6x2_to_f16x2_e2m3
+llvm.func @convert_f6x2_to_f16x2_e2m3(%src : vector<2xi8>) {
+ // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m3x2.to.f16x2.rn(i16 %[[res1]])
+  %res1 = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f6E2M3FN) -> vector<2xf16>
+  // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m3x2.to.f16x2.rn.relu(i16 %[[res2]])
+  %res2 = nvvm.convert.f6x2.to.f16x2 %src {relu = true} : vector<2xi8> (f6E2M3FN) -> vector<2xf16>
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f6x2_to_f16x2_e3m2
+llvm.func @convert_f6x2_to_f16x2_e3m2(%src : vector<2xi8>) {
+ // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e3m2x2.to.f16x2.rn(i16 %[[res1]])
+  %res1 = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f6E3M2FN) -> vector<2xf16>
+  // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e3m2x2.to.f16x2.rn.relu(i16 %[[res2]])
+  %res2 = nvvm.convert.f6x2.to.f16x2 %src {relu = true} : vector<2xi8> (f6E3M2FN) -> vector<2xf16>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
index de21826..4afe901 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
@@ -100,3 +100,37 @@ llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) {
%res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU)
llvm.return
}
+
+// -----
+
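+// As with f6x2, each f8x2 source is a vector<2xi8> that is bitcast to i16
+// before the intrinsic call.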
+// CHECK-LABEL: @convert_f8x2_to_f16x2_e4m3
+llvm.func @convert_f8x2_to_f16x2_e4m3(%src : vector<2xi8>) {
+ // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn(i16 %[[res1]])
+  %res1 = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xf16>
+  // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn.relu(i16 %[[res2]])
+  %res2 = nvvm.convert.f8x2.to.f16x2 %src {relu = true} : vector<2xi8> (f8E4M3FN) -> vector<2xf16>
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f8x2_to_f16x2_e5m2
+llvm.func @convert_f8x2_to_f16x2_e5m2(%src : vector<2xi8>) {
+ // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn(i16 %[[res1]])
+  %res1 = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E5M2) -> vector<2xf16>
+  // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %[[res2]])
+  %res2 = nvvm.convert.f8x2.to.f16x2 %src {relu = true} : vector<2xi8> (f8E5M2) -> vector<2xf16>
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @convert_f8x2_to_bf16x2_ue8m0
+llvm.func @convert_f8x2_to_bf16x2_ue8m0(%src : vector<2xi8>) {
+ // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ue8m0x2.to.bf16x2(i16 %[[res1]])
+  %res1 = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E8M0FNU) -> vector<2xbf16>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 6cccfe4..09b8f59 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -262,6 +262,38 @@ llvm.func @nvvm_cvt_f32x2_to_f4x2_invalid_type(%a : f32, %b : f32) {
// -----
+llvm.func @nvvm_cvt_f8x2_to_f16x2_invalid_type(%src : vector<2xi8>) {
+ // expected-error @below {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f8x2 to f16x2.}}
+ %res = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E4M3) -> vector<2xf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f8x2_to_bf16x2_invalid_type(%src : vector<2xi8>) {
+ // expected-error @below {{Only 'f8E8M0FNU' type is supported for conversions from f8x2 to bf16x2.}}
+ %res = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f6x2_to_f16x2_invalid_type(%src : vector<2xi8>) {
+ // expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f6x2 to f16x2.}}
+ %res = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f4x2_to_f16x2_invalid_type(%src : i8) {
+ // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f4x2 to f16x2.}}
+ %res = nvvm.convert.f4x2.to.f16x2 %src : i8 (f6E2M3FN) -> vector<2xf16>
+ llvm.return
+}
+
+// -----
+
llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{cache eviction priority supported only for cache level L2}}
nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 2fa4470..af6d254 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -249,24 +249,6 @@ llvm.func @target_is_device_ptr(%x : !llvm.ptr) {
// -----
-omp.private {type = firstprivate} @x.privatizer : i32 copy {
-^bb0(%mold: !llvm.ptr, %private: !llvm.ptr):
- %0 = llvm.load %mold : !llvm.ptr -> i32
- llvm.store %0, %private : i32, !llvm.ptr
- omp.yield(%private: !llvm.ptr)
-}
-llvm.func @target_firstprivate(%x : !llvm.ptr) {
- %0 = omp.map.info var_ptr(%x : !llvm.ptr, i32) map_clauses(to) capture(ByRef) -> !llvm.ptr
- // expected-error@below {{not yet implemented: Unhandled clause privatization for deferred target tasks in omp.target operation}}
- // expected-error@below {{LLVM Translation failed for operation: omp.target}}
- omp.target nowait map_entries(%0 -> %blockarg0 : !llvm.ptr) private(@x.privatizer %x -> %arg0 [map_idx=0] : !llvm.ptr) {
- omp.terminator
- }
- llvm.return
-}
-
-// -----
-
llvm.func @target_enter_data_depend(%x: !llvm.ptr) {
// expected-error@below {{not yet implemented: Unhandled clause depend in omp.target_enter_data operation}}
// expected-error@below {{LLVM Translation failed for operation: omp.target_enter_data}}
diff --git a/mlir/test/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/test/lib/Dialect/OpenACC/CMakeLists.txt
index 1e59338..a54b642 100644
--- a/mlir/test/lib/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/OpenACC/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_library(MLIROpenACCTestPasses
TestOpenACC.cpp
TestPointerLikeTypeInterface.cpp
TestRecipePopulate.cpp
+ TestOpenACCSupport.cpp
EXCLUDE_FROM_LIBMLIR
)
@@ -11,6 +12,7 @@ mlir_target_link_libraries(MLIROpenACCTestPasses PUBLIC
MLIRFuncDialect
MLIRMemRefDialect
MLIROpenACCDialect
+ MLIROpenACCAnalysis
MLIRPass
MLIRSupport
)
diff --git a/mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp b/mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp
index bea21b9..e59d777 100644
--- a/mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp
+++ b/mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp
@@ -16,11 +16,13 @@ namespace test {
// Forward declarations of individual test pass registration functions
void registerTestPointerLikeTypeInterfacePass();
void registerTestRecipePopulatePass();
+void registerTestOpenACCSupportPass();
// Unified registration function for all OpenACC tests
void registerTestOpenACC() {
registerTestPointerLikeTypeInterfacePass();
registerTestRecipePopulatePass();
+ registerTestOpenACCSupportPass();
}
} // namespace test
diff --git a/mlir/test/lib/Dialect/OpenACC/TestOpenACCSupport.cpp b/mlir/test/lib/Dialect/OpenACC/TestOpenACCSupport.cpp
new file mode 100644
index 0000000..8bf984b
--- /dev/null
+++ b/mlir/test/lib/Dialect/OpenACC/TestOpenACCSupport.cpp
@@ -0,0 +1,73 @@
+//===- TestOpenACCSupport.cpp - Test OpenACCSupport Analysis -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes for testing the OpenACCSupport analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::acc;
+
+namespace {
+
+struct TestOpenACCSupportPass
+ : public PassWrapper<TestOpenACCSupportPass, OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOpenACCSupportPass)
+
+ StringRef getArgument() const override { return "test-acc-support"; }
+
+ StringRef getDescription() const override {
+ return "Test OpenACCSupport analysis";
+ }
+
+ void runOnOperation() override;
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<acc::OpenACCDialect>();
+ registry.insert<memref::MemRefDialect>();
+ }
+};
+
+void TestOpenACCSupportPass::runOnOperation() {
+ auto func = getOperation();
+
+ // Get the OpenACCSupport analysis
+ OpenACCSupport &support = getAnalysis<OpenACCSupport>();
+
+ // Walk through operations looking for test attributes
+ func.walk([&](Operation *op) {
+ // Check for test.var_name attribute. This is the marker used to identify
+ // the operations that need to be tested for getVariableName.
+ if (op->hasAttr("test.var_name")) {
+ // For each result of this operation, try to get the variable name
+ for (auto result : op->getResults()) {
+ std::string foundName = support.getVariableName(result);
+ llvm::outs() << "op=" << *op << "\n\tgetVariableName=\"" << foundName
+ << "\"\n";
+ }
+ }
+ });
+}
+
+} // namespace
+
+namespace mlir {
+namespace test {
+
+void registerTestOpenACCSupportPass() {
+ PassRegistration<TestOpenACCSupportPass>();
+}
+
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index ee4fa39..efbdbfb 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -2136,7 +2136,7 @@ struct TestTypeConversionDriver
Location loc) -> Value {
if (inputs.size() != 1 || !inputs[0].getType().isInteger(37))
return Value();
- return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
+ return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
.getResult(0);
});
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp
index ab817b6..3fbbcc9 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp
@@ -410,3 +410,78 @@ TEST_F(OpenACCUtilsTest, getTypeCategoryArray) {
VariableTypeCategory category = getTypeCategory(varPtr);
EXPECT_EQ(category, VariableTypeCategory::array);
}
+
+//===----------------------------------------------------------------------===//
+// getVariableName Tests
+//===----------------------------------------------------------------------===//
+
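+// These tests exercise getVariableName, which reads the acc.var_name
+// attribute from a value's defining operation, looks through casts such as
+// memref.cast, and falls back to the name recorded on data-clause operations
+// like acc.copyin.
+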
+TEST_F(OpenACCUtilsTest, getVariableNameDirect) {
+ // Create a memref with acc.var_name attribute
+ auto memrefTy = MemRefType::get({10}, b.getI32Type());
+ OwningOpRef<memref::AllocaOp> allocOp =
+ memref::AllocaOp::create(b, loc, memrefTy);
+
+ // Set the acc.var_name attribute
+ auto varNameAttr = VarNameAttr::get(&context, "my_variable");
+ allocOp.get()->setAttr(getVarNameAttrName(), varNameAttr);
+
+ Value varPtr = allocOp->getResult();
+
+ // Test that getVariableName returns the variable name
+ std::string varName = getVariableName(varPtr);
+ EXPECT_EQ(varName, "my_variable");
+}
+
+TEST_F(OpenACCUtilsTest, getVariableNameThroughCast) {
+ // Create a 5x2 memref with acc.var_name attribute
+ auto memrefTy = MemRefType::get({5, 2}, b.getI32Type());
+ OwningOpRef<memref::AllocaOp> allocOp =
+ memref::AllocaOp::create(b, loc, memrefTy);
+
+ // Set the acc.var_name attribute on the alloca
+ auto varNameAttr = VarNameAttr::get(&context, "casted_variable");
+ allocOp.get()->setAttr(getVarNameAttrName(), varNameAttr);
+
+ Value allocResult = allocOp->getResult();
+
+ // Create a memref.cast operation to a flattened 10-element array
+ auto castedMemrefTy = MemRefType::get({10}, b.getI32Type());
+ OwningOpRef<memref::CastOp> castOp =
+ memref::CastOp::create(b, loc, castedMemrefTy, allocResult);
+
+ Value castedPtr = castOp->getResult();
+
+ // Test that getVariableName walks through the cast to find the variable name
+ std::string varName = getVariableName(castedPtr);
+ EXPECT_EQ(varName, "casted_variable");
+}
+
+TEST_F(OpenACCUtilsTest, getVariableNameNotFound) {
+ // Create a memref without acc.var_name attribute
+ auto memrefTy = MemRefType::get({10}, b.getI32Type());
+ OwningOpRef<memref::AllocaOp> allocOp =
+ memref::AllocaOp::create(b, loc, memrefTy);
+
+ Value varPtr = allocOp->getResult();
+
+ // Test that getVariableName returns empty string when no name is found
+ std::string varName = getVariableName(varPtr);
+ EXPECT_EQ(varName, "");
+}
+
+TEST_F(OpenACCUtilsTest, getVariableNameFromCopyin) {
+ // Create a memref
+ auto memrefTy = MemRefType::get({10}, b.getI32Type());
+ OwningOpRef<memref::AllocaOp> allocOp =
+ memref::AllocaOp::create(b, loc, memrefTy);
+
+ Value varPtr = allocOp->getResult();
+ StringRef name = "data_array";
+ OwningOpRef<CopyinOp> copyinOp =
+ CopyinOp::create(b, loc, varPtr, /*structured=*/true, /*implicit=*/true,
+ /*name=*/name);
+
+ // Test that getVariableName extracts the name from the copyin operation
+ std::string varName = getVariableName(copyinOp->getAccVar());
+ EXPECT_EQ(varName, name);
+}