aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/test/EDSC/builder-api-test.cpp106
1 files changed, 33 insertions, 73 deletions
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index 50409434..2ad3f0c 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -344,7 +344,10 @@ TEST_FUNC(builder_helpers) {
using namespace edsc::intrinsics;
using namespace edsc::op;
auto f32Type = FloatType::getF32(&globalContext());
- auto memrefType = MemRefType::get({-1, -1, -1}, f32Type, {}, 0);
+ auto memrefType =
+ MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize,
+ ShapedType::kDynamicSize},
+ f32Type, {}, 0);
auto f =
makeFunction("builder_helpers", {}, {memrefType, memrefType, memrefType});
@@ -491,7 +494,8 @@ TEST_FUNC(select_op_i32) {
using namespace edsc::intrinsics;
using namespace edsc::op;
auto f32Type = FloatType::getF32(&globalContext());
- auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
+ auto memrefType = MemRefType::get(
+ {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0);
auto f = makeFunction("select_op", {}, {memrefType});
OpBuilder builder(f.getBody());
@@ -526,7 +530,8 @@ TEST_FUNC(select_op_f32) {
using namespace edsc::intrinsics;
using namespace edsc::op;
auto f32Type = FloatType::getF32(&globalContext());
- auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
+ auto memrefType = MemRefType::get(
+ {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0);
auto f = makeFunction("select_op", {}, {memrefType, memrefType});
OpBuilder builder(f.getBody());
@@ -603,7 +608,9 @@ TEST_FUNC(tile_2d) {
using namespace edsc::intrinsics;
using namespace edsc::op;
auto memrefType =
- MemRefType::get({-1, -1, -1}, FloatType::getF32(&globalContext()), {}, 0);
+ MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize,
+ ShapedType::kDynamicSize},
+ FloatType::getF32(&globalContext()), {}, 0);
auto f = makeFunction("tile_2d", {}, {memrefType, memrefType, memrefType});
OpBuilder builder(f.getBody());
@@ -666,73 +673,13 @@ TEST_FUNC(tile_2d) {
f.erase();
}
-// Inject an EDSC-constructed computation to exercise 2-d vectorization.
-// TODO(ntv,andydavis) Convert EDSC to use AffineLoad/Store.
-/*
-TEST_FUNC(vectorize_2d) {
- using namespace edsc;
- using namespace edsc::intrinsics;
- using namespace edsc::op;
- auto memrefType =
- MemRefType::get({-1, -1, -1}, FloatType::getF32(&globalContext()), {}, 0);
- auto owningF =
- makeFunction("vectorize_2d", {}, {memrefType, memrefType, memrefType});
-
- mlir::FuncOp f = owningF;
- mlir::OwningModuleRef module = ModuleOp::create(&globalContext());
- module->push_back(f);
-
- OpBuilder builder(f.getBody());
- ScopedContext scope(builder, f.getLoc());
- ValueHandle zero = constant_index(0);
- MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), vC(f.getArgument(2));
- IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2));
- IndexHandle M(vA.ub(0)), N(vA.ub(1)), P(vA.ub(2));
-
- // clang-format off
- IndexHandle i, j, k;
- AffineLoopNestBuilder({&i, &j, &k}, {zero, zero, zero}, {M, N, P}, {1, 1,
-1})([&]{ C(i, j, k) = A(i, j, k) + B(i, j, k);
- });
- ret();
-
- // xCHECK-LABEL: func @vectorize_2d
- // xCHECK-NEXT: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?x?xf32>
- // xCHECK-NEXT: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?x?xf32>
- // xCHECK-NEXT: %[[P:.*]] = dim %{{.*}}, 2 : memref<?x?x?xf32>
- // xCHECK-NEXT: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[M]]) {
- // xCHECK-NEXT: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[N]]) step 4 {
- // xCHECK-NEXT: affine.for %{{.*}} = 0 to (d0) -> (d0)(%[[P]]) step 4 {
- // xCHECK-NEXT: %[[vA:.*]] = "vector.transfer_read"(%{{.*}}, %{{.*}},
-%{{.*}}, %i2) {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>} :
-(memref<?x?x?xf32>, index, index, index) -> vector<4x4xf32>
- // xCHECK-NEXT: %[[vB:.*]] = "vector.transfer_read"(%{{.*}}, %{{.*}},
-%{{.*}}, %i2) {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>} :
-(memref<?x?x?xf32>, index, index, index) -> vector<4x4xf32>
- // xCHECK-NEXT: %[[vRES:.*]] = addf %[[vB]], %[[vA]] : vector<4x4xf32>
- // xCHECK-NEXT: "vector.transfer_write"(%[[vRES:.*]], %{{.*}}, %{{.*}},
-%{{.*}}, %i2) {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>} :
-(vector<4x4xf32>, memref<?x?x?xf32>, index, index, index) -> ()
- // clang-format on
-
- mlir::PassManager pm;
- pm.addPass(mlir::createCanonicalizerPass());
- SmallVector<int64_t, 2> vectorSizes{4, 4};
- pm.addPass(mlir::createVectorizePass(vectorSizes));
- auto result = pm.run(f.getModule());
- if (succeeded(result))
- f.print(llvm::outs());
- f.erase();
-}
-*/
-
// Exercise StdIndexedValue for loads and stores.
TEST_FUNC(indirect_access) {
using namespace edsc;
using namespace edsc::intrinsics;
using namespace edsc::op;
- auto memrefType =
- MemRefType::get({-1}, FloatType::getF32(&globalContext()), {}, 0);
+ auto memrefType = MemRefType::get({ShapedType::kDynamicSize},
+ FloatType::getF32(&globalContext()), {}, 0);
auto f = makeFunction("indirect_access", {},
{memrefType, memrefType, memrefType, memrefType});
@@ -794,17 +741,20 @@ TEST_FUNC(empty_map_load_store) {
f.erase();
}
+// clang-format off
// CHECK-LABEL: func @affine_if_op
// CHECK: affine.if affine_set<([[d0:.*]], [[d1:.*]]){{\[}}[[s0:.*]], [[s1:.*]]{{\]}}
// CHECK-NOT: else
// CHECK: affine.if affine_set<([[d0:.*]], [[d1:.*]]){{\[}}[[s0:.*]], [[s1:.*]]{{\]}}
// CHECK-NEXT: } else {
+// clang-format on
TEST_FUNC(affine_if_op) {
using namespace edsc;
using namespace edsc::intrinsics;
using namespace edsc::op;
auto f32Type = FloatType::getF32(&globalContext());
- auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
+ auto memrefType = MemRefType::get(
+ {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0);
auto f = makeFunction("affine_if_op", {}, {memrefType});
OpBuilder builder(f.getBody());
@@ -853,7 +803,8 @@ TEST_FUNC(linalg_pointwise_test) {
using namespace edsc::ops;
auto f32Type = FloatType::getF32(&globalContext());
- auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
+ auto memrefType = MemRefType::get(
+ {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0);
auto f = makeFunction("linalg_pointwise", {},
{memrefType, memrefType, memrefType});
@@ -887,7 +838,8 @@ TEST_FUNC(linalg_matmul_test) {
using namespace edsc::ops;
auto f32Type = FloatType::getF32(&globalContext());
- auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
+ auto memrefType = MemRefType::get(
+ {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0);
auto f =
makeFunction("linalg_matmul", {}, {memrefType, memrefType, memrefType});
@@ -917,7 +869,10 @@ TEST_FUNC(linalg_conv_nhwc) {
using namespace edsc::ops;
auto f32Type = FloatType::getF32(&globalContext());
- auto memrefType = MemRefType::get({-1, -1, -1, -1}, f32Type, {}, 0);
+ auto memrefType =
+ MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize,
+ ShapedType::kDynamicSize, ShapedType::kDynamicSize},
+ f32Type, {}, 0);
auto f = makeFunction("linalg_conv_nhwc", {},
{memrefType, memrefType, memrefType});
@@ -948,7 +903,10 @@ TEST_FUNC(linalg_dilated_conv_nhwc) {
using namespace edsc::ops;
auto f32Type = FloatType::getF32(&globalContext());
- auto memrefType = MemRefType::get({-1, -1, -1, -1}, f32Type, {}, 0);
+ auto memrefType =
+ MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize,
+ ShapedType::kDynamicSize, ShapedType::kDynamicSize},
+ f32Type, {}, 0);
auto f = makeFunction("linalg_dilated_conv_nhwc", {},
{memrefType, memrefType, memrefType});
@@ -1029,8 +987,10 @@ TEST_FUNC(linalg_tensors_test) {
using namespace edsc::ops;
auto f32Type = FloatType::getF32(&globalContext());
- auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0);
- auto tensorType = RankedTensorType::get({-1, -1}, f32Type);
+ auto memrefType = MemRefType::get(
+ {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0);
+ auto tensorType = RankedTensorType::get(
+ {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type);
auto f = makeFunction("linalg_tensors", {}, {tensorType, memrefType});
OpBuilder builder(f.getBody());