aboutsummaryrefslogtreecommitdiff
path: root/mlir/unittests
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/unittests')
-rw-r--r--mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp85
-rw-r--r--mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp38
2 files changed, 116 insertions, 7 deletions
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp
index 3fbbcc9..f1fe53c 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp
@@ -485,3 +485,88 @@ TEST_F(OpenACCUtilsTest, getVariableNameFromCopyin) {
std::string varName = getVariableName(copyinOp->getAccVar());
EXPECT_EQ(varName, name);
}
+
+//===----------------------------------------------------------------------===//
+// getRecipeName Tests
+//===----------------------------------------------------------------------===//
+
+TEST_F(OpenACCUtilsTest, getRecipeNamePrivateScalarMemref) {
+ // Create a scalar memref type
+ auto scalarMemrefTy = MemRefType::get({}, b.getI32Type());
+
+ // Test private recipe with scalar memref
+ std::string recipeName =
+ getRecipeName(RecipeKind::private_recipe, scalarMemrefTy);
+ EXPECT_EQ(recipeName, "privatization_memref_i32_");
+}
+
+TEST_F(OpenACCUtilsTest, getRecipeNameFirstprivateScalarMemref) {
+ // Create a scalar memref type
+ auto scalarMemrefTy = MemRefType::get({}, b.getF32Type());
+
+ // Test firstprivate recipe with scalar memref
+ std::string recipeName =
+ getRecipeName(RecipeKind::firstprivate_recipe, scalarMemrefTy);
+ EXPECT_EQ(recipeName, "firstprivatization_memref_f32_");
+}
+
+TEST_F(OpenACCUtilsTest, getRecipeNameReductionScalarMemref) {
+ // Create a scalar memref type
+ auto scalarMemrefTy = MemRefType::get({}, b.getI64Type());
+
+ // Test reduction recipe with scalar memref
+ std::string recipeName =
+ getRecipeName(RecipeKind::reduction_recipe, scalarMemrefTy);
+ EXPECT_EQ(recipeName, "reduction_memref_i64_");
+}
+
+TEST_F(OpenACCUtilsTest, getRecipeNamePrivate2DMemref) {
+ // Create a 2D memref type
+ auto memref2DTy = MemRefType::get({5, 10}, b.getF32Type());
+
+ // Test private recipe with 2D memref
+ std::string recipeName =
+ getRecipeName(RecipeKind::private_recipe, memref2DTy);
+ EXPECT_EQ(recipeName, "privatization_memref_5x10xf32_");
+}
+
+TEST_F(OpenACCUtilsTest, getRecipeNameFirstprivate2DMemref) {
+ // Create a 2D memref type
+ auto memref2DTy = MemRefType::get({8, 16}, b.getF64Type());
+
+ // Test firstprivate recipe with 2D memref
+ std::string recipeName =
+ getRecipeName(RecipeKind::firstprivate_recipe, memref2DTy);
+ EXPECT_EQ(recipeName, "firstprivatization_memref_8x16xf64_");
+}
+
+TEST_F(OpenACCUtilsTest, getRecipeNameReduction2DMemref) {
+ // Create a 2D memref type
+ auto memref2DTy = MemRefType::get({4, 8}, b.getI32Type());
+
+ // Test reduction recipe with 2D memref
+ std::string recipeName =
+ getRecipeName(RecipeKind::reduction_recipe, memref2DTy);
+ EXPECT_EQ(recipeName, "reduction_memref_4x8xi32_");
+}
+
+TEST_F(OpenACCUtilsTest, getRecipeNamePrivateDynamicMemref) {
+ // Create a memref with dynamic dimensions
+ auto dynamicMemrefTy =
+ MemRefType::get({ShapedType::kDynamic, 10}, b.getI32Type());
+
+ // Test private recipe with dynamic memref
+ std::string recipeName =
+ getRecipeName(RecipeKind::private_recipe, dynamicMemrefTy);
+ EXPECT_EQ(recipeName, "privatization_memref_Ux10xi32_");
+}
+
+TEST_F(OpenACCUtilsTest, getRecipeNamePrivateUnrankedMemref) {
+ // Create an unranked memref type
+ auto unrankedMemrefTy = UnrankedMemRefType::get(b.getI32Type(), 0);
+
+ // Test private recipe with unranked memref
+ std::string recipeName =
+ getRecipeName(RecipeKind::private_recipe, unrankedMemrefTy);
+ EXPECT_EQ(recipeName, "privatization_memref_Zxi32_");
+}
diff --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
index f1aae15..2e6950f 100644
--- a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
@@ -13,17 +13,24 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Parser/Parser.h"
+#include "llvm/Support/DebugLog.h"
#include <gtest/gtest.h>
using namespace mlir;
/// A dummy op that is also a terminator.
-struct DummyOp : public Op<DummyOp, OpTrait::IsTerminator> {
+struct DummyOp : public Op<DummyOp, OpTrait::IsTerminator, OpTrait::ZeroResults,
+ OpTrait::ZeroSuccessors,
+ RegionBranchTerminatorOpInterface::Trait> {
using Op::Op;
static ArrayRef<StringRef> getAttributeNames() { return {}; }
static StringRef getOperationName() { return "cftest.dummy_op"; }
+
+ MutableOperandRange getMutableSuccessorOperands(RegionSuccessor point) {
+ return MutableOperandRange(getOperation(), 0, 0);
+ }
};
/// All regions of this op are mutually exclusive.
@@ -39,6 +46,8 @@ struct MutuallyExclusiveRegionsOp
// Regions have no successors.
void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {}
+ using RegionBranchOpInterface::Trait<
+ MutuallyExclusiveRegionsOp>::getSuccessorRegions;
};
/// All regions of this op call each other in a large circle.
@@ -53,13 +62,18 @@ struct LoopRegionsOp
void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- if (Region *region = point.getRegionOrNull()) {
- if (point == (*this)->getRegion(1))
+ if (point.getTerminatorPredecessorOrNull()) {
+ Region *region =
+ point.getTerminatorPredecessorOrNull()->getParentRegion();
+ if (region == &(*this)->getRegion(1))
// This region also branches back to the parent.
- regions.push_back(RegionSuccessor());
+ regions.push_back(
+ RegionSuccessor(getOperation()->getParentOp(),
+ getOperation()->getParentOp()->getResults()));
regions.push_back(RegionSuccessor(region));
}
}
+ using RegionBranchOpInterface::Trait<LoopRegionsOp>::getSuccessorRegions;
};
/// Each region branches back it itself or the parent.
@@ -75,11 +89,17 @@ struct DoubleLoopRegionsOp
void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- if (Region *region = point.getRegionOrNull()) {
- regions.push_back(RegionSuccessor());
+ if (point.getTerminatorPredecessorOrNull()) {
+ Region *region =
+ point.getTerminatorPredecessorOrNull()->getParentRegion();
+ regions.push_back(
+ RegionSuccessor(getOperation()->getParentOp(),
+ getOperation()->getParentOp()->getResults()));
regions.push_back(RegionSuccessor(region));
}
}
+ using RegionBranchOpInterface::Trait<
+ DoubleLoopRegionsOp>::getSuccessorRegions;
};
/// Regions are executed sequentially.
@@ -93,11 +113,15 @@ struct SequentialRegionsOp
// Region 0 has Region 1 as a successor.
void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- if (point == (*this)->getRegion(0)) {
+ if (point.getTerminatorPredecessorOrNull() &&
+ point.getTerminatorPredecessorOrNull()->getParentRegion() ==
+ &(*this)->getRegion(0)) {
Operation *thisOp = this->getOperation();
regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
}
}
+ using RegionBranchOpInterface::Trait<
+ SequentialRegionsOp>::getSuccessorRegions;
};
/// A dialect putting all the above together.