diff options
Diffstat (limited to 'mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp')
| -rw-r--r-- | mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp | 189 | 
1 files changed, 189 insertions, 0 deletions
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp index 3fbbcc9..6f4e305 100644 --- a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp +++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp @@ -485,3 +485,192 @@ TEST_F(OpenACCUtilsTest, getVariableNameFromCopyin) {    std::string varName = getVariableName(copyinOp->getAccVar());    EXPECT_EQ(varName, name);  } + +//===----------------------------------------------------------------------===// +// getRecipeName Tests +//===----------------------------------------------------------------------===// + +TEST_F(OpenACCUtilsTest, getRecipeNamePrivateScalarMemref) { +  // Create a scalar memref type +  auto scalarMemrefTy = MemRefType::get({}, b.getI32Type()); + +  // Test private recipe with scalar memref +  std::string recipeName = +      getRecipeName(RecipeKind::private_recipe, scalarMemrefTy); +  EXPECT_EQ(recipeName, "privatization_memref_i32_"); +} + +TEST_F(OpenACCUtilsTest, getRecipeNameFirstprivateScalarMemref) { +  // Create a scalar memref type +  auto scalarMemrefTy = MemRefType::get({}, b.getF32Type()); + +  // Test firstprivate recipe with scalar memref +  std::string recipeName = +      getRecipeName(RecipeKind::firstprivate_recipe, scalarMemrefTy); +  EXPECT_EQ(recipeName, "firstprivatization_memref_f32_"); +} + +TEST_F(OpenACCUtilsTest, getRecipeNameReductionScalarMemref) { +  // Create a scalar memref type +  auto scalarMemrefTy = MemRefType::get({}, b.getI64Type()); + +  // Test reduction recipe with scalar memref +  std::string recipeName = +      getRecipeName(RecipeKind::reduction_recipe, scalarMemrefTy); +  EXPECT_EQ(recipeName, "reduction_memref_i64_"); +} + +TEST_F(OpenACCUtilsTest, getRecipeNamePrivate2DMemref) { +  // Create a 2D memref type +  auto memref2DTy = MemRefType::get({5, 10}, b.getF32Type()); + +  // Test private recipe with 2D memref +  std::string recipeName = +      getRecipeName(RecipeKind::private_recipe, memref2DTy); +  EXPECT_EQ(recipeName, "privatization_memref_5x10xf32_"); +} + +TEST_F(OpenACCUtilsTest, getRecipeNameFirstprivate2DMemref) { +  // Create a 2D memref type +  auto memref2DTy = MemRefType::get({8, 16}, b.getF64Type()); + +  // Test firstprivate recipe with 2D memref +  std::string recipeName = +      getRecipeName(RecipeKind::firstprivate_recipe, memref2DTy); +  EXPECT_EQ(recipeName, "firstprivatization_memref_8x16xf64_"); +} + +TEST_F(OpenACCUtilsTest, getRecipeNameReduction2DMemref) { +  // Create a 2D memref type +  auto memref2DTy = MemRefType::get({4, 8}, b.getI32Type()); + +  // Test reduction recipe with 2D memref +  std::string recipeName = +      getRecipeName(RecipeKind::reduction_recipe, memref2DTy); +  EXPECT_EQ(recipeName, "reduction_memref_4x8xi32_"); +} + +TEST_F(OpenACCUtilsTest, getRecipeNamePrivateDynamicMemref) { +  // Create a memref with dynamic dimensions +  auto dynamicMemrefTy = +      MemRefType::get({ShapedType::kDynamic, 10}, b.getI32Type()); + +  // Test private recipe with dynamic memref +  std::string recipeName = +      getRecipeName(RecipeKind::private_recipe, dynamicMemrefTy); +  EXPECT_EQ(recipeName, "privatization_memref_Ux10xi32_"); +} + +TEST_F(OpenACCUtilsTest, getRecipeNamePrivateUnrankedMemref) { +  // Create an unranked memref type +  auto unrankedMemrefTy = UnrankedMemRefType::get(b.getI32Type(), 0); + +  // Test private recipe with unranked memref +  std::string recipeName = +      getRecipeName(RecipeKind::private_recipe, unrankedMemrefTy); +  EXPECT_EQ(recipeName, "privatization_memref_Zxi32_"); +} + +//===----------------------------------------------------------------------===// +// getBaseEntity Tests +//===----------------------------------------------------------------------===// + +// Local implementation of PartialEntityAccessOpInterface for memref.subview. +// This is implemented locally in the test rather than officially because memref +// operations already have ViewLikeOpInterface, which serves a similar purpose +// for walking through views to the base entity. This test demonstrates how +// getBaseEntity() would work if the interface were attached to memref.subview. +namespace { +struct SubViewOpPartialEntityAccessOpInterface +    : public acc::PartialEntityAccessOpInterface::ExternalModel< +          SubViewOpPartialEntityAccessOpInterface, memref::SubViewOp> { +  Value getBaseEntity(Operation *op) const { +    auto subviewOp = cast<memref::SubViewOp>(op); +    return subviewOp.getSource(); +  } + +  bool isCompleteView(Operation *op) const { +    // For testing purposes, we'll consider it a partial view (return false). +    // The real implementation would need to look at the offsets. +    return false; +  } +}; +} // namespace + +TEST_F(OpenACCUtilsTest, getBaseEntityFromSubview) { +  // Register the local interface implementation for memref.subview +  memref::SubViewOp::attachInterface<SubViewOpPartialEntityAccessOpInterface>( +      context); + +  // Create a base memref +  auto memrefTy = MemRefType::get({10, 20}, b.getF32Type()); +  OwningOpRef<memref::AllocaOp> allocOp = +      memref::AllocaOp::create(b, loc, memrefTy); +  Value baseMemref = allocOp->getResult(); + +  // Create a subview of the base memref with non-zero offsets +  // This creates a 5x10 view starting at [2, 3] in the original 10x20 memref +  SmallVector<OpFoldResult> offsets = {b.getIndexAttr(2), b.getIndexAttr(3)}; +  SmallVector<OpFoldResult> sizes = {b.getIndexAttr(5), b.getIndexAttr(10)}; +  SmallVector<OpFoldResult> strides = {b.getIndexAttr(1), b.getIndexAttr(1)}; + +  OwningOpRef<memref::SubViewOp> subviewOp = +      memref::SubViewOp::create(b, loc, baseMemref, offsets, sizes, strides); +  Value subview = subviewOp->getResult(); + +  // Test that getBaseEntity returns the base memref, not the subview +  Value baseEntity = getBaseEntity(subview); +  EXPECT_EQ(baseEntity, baseMemref); +} + +TEST_F(OpenACCUtilsTest, getBaseEntityNoInterface) { +  // Create a memref without the interface +  auto memrefTy = MemRefType::get({10}, b.getI32Type()); +  OwningOpRef<memref::AllocaOp> allocOp = +      memref::AllocaOp::create(b, loc, memrefTy); +  Value varPtr = allocOp->getResult(); + +  // Test that getBaseEntity returns the value itself when there's no interface +  Value baseEntity = getBaseEntity(varPtr); +  EXPECT_EQ(baseEntity, varPtr); +} + +TEST_F(OpenACCUtilsTest, getBaseEntityChainedSubviews) { +  // Register the local interface implementation for memref.subview +  memref::SubViewOp::attachInterface<SubViewOpPartialEntityAccessOpInterface>( +      context); + +  // Create a base memref +  auto memrefTy = MemRefType::get({100, 200}, b.getI64Type()); +  OwningOpRef<memref::AllocaOp> allocOp = +      memref::AllocaOp::create(b, loc, memrefTy); +  Value baseMemref = allocOp->getResult(); + +  // Create first subview +  SmallVector<OpFoldResult> offsets1 = {b.getIndexAttr(10), b.getIndexAttr(20)}; +  SmallVector<OpFoldResult> sizes1 = {b.getIndexAttr(50), b.getIndexAttr(80)}; +  SmallVector<OpFoldResult> strides1 = {b.getIndexAttr(1), b.getIndexAttr(1)}; + +  OwningOpRef<memref::SubViewOp> subview1Op = +      memref::SubViewOp::create(b, loc, baseMemref, offsets1, sizes1, strides1); +  Value subview1 = subview1Op->getResult(); + +  // Create second subview (subview of subview) +  SmallVector<OpFoldResult> offsets2 = {b.getIndexAttr(5), b.getIndexAttr(10)}; +  SmallVector<OpFoldResult> sizes2 = {b.getIndexAttr(20), b.getIndexAttr(30)}; +  SmallVector<OpFoldResult> strides2 = {b.getIndexAttr(1), b.getIndexAttr(1)}; + +  OwningOpRef<memref::SubViewOp> subview2Op = +      memref::SubViewOp::create(b, loc, subview1, offsets2, sizes2, strides2); +  Value subview2 = subview2Op->getResult(); + +  // Test that getBaseEntity on the nested subview returns the first subview +  // (since our implementation returns the immediate source, not the ultimate +  // base) +  Value baseEntity = getBaseEntity(subview2); +  EXPECT_EQ(baseEntity, subview1); + +  // Test that calling getBaseEntity again returns the original base +  Value ultimateBase = getBaseEntity(baseEntity); +  EXPECT_EQ(ultimateBase, baseMemref); +}  | 
