aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon Camphausen <simon.camphausen@iml.fraunhofer.de>2024-04-03 13:06:14 +0200
committerGitHub <noreply@github.com>2024-04-03 13:06:14 +0200
commit1f268092c7af20c21d4a594678b647cab050602a (patch)
tree40e2284e8668209a073682ed5ddd0cb0dd782b38
parentd0dcf06ab8723cc4358ad446354cce875dd89577 (diff)
downloadllvm-1f268092c7af20c21d4a594678b647cab050602a.zip
llvm-1f268092c7af20c21d4a594678b647cab050602a.tar.gz
llvm-1f268092c7af20c21d4a594678b647cab050602a.tar.bz2
[mlir][EmitC] Add support for pointer and opaque types to subscript op (#86266)
For pointer types the indices are restricted to one integer-like operand. For opaque types no further restrictions are made.
-rw-r--r--mlir/include/mlir/Dialect/EmitC/IR/EmitC.h6
-rw-r--r--mlir/include/mlir/Dialect/EmitC/IR/EmitC.td30
-rw-r--r--mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp15
-rw-r--r--mlir/lib/Dialect/EmitC/IR/EmitC.cpp64
-rw-r--r--mlir/lib/Target/Cpp/TranslateToCpp.cpp2
-rw-r--r--mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir4
-rw-r--r--mlir/test/Dialect/EmitC/invalid_ops.mlir46
-rw-r--r--mlir/test/Dialect/EmitC/ops.mlir7
-rw-r--r--mlir/test/Target/Cpp/subscript.mlir32
9 files changed, 175 insertions, 31 deletions
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 725a1bc..c039156 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -30,8 +30,14 @@
namespace mlir {
namespace emitc {
void buildTerminatedBody(OpBuilder &builder, Location loc);
+
/// Determines whether \p type is a valid integer type in EmitC.
bool isSupportedIntegerType(mlir::Type type);
+
+/// Determines whether \p type is integer like, i.e. it's a supported integer,
+/// an index or opaque type.
+bool isIntegerIndexOrOpaqueType(Type type);
+
/// Determines whether \p type is a valid floating-point type in EmitC.
bool isSupportedFloatType(mlir::Type type);
} // namespace emitc
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index d746222..090dae8 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1155,35 +1155,41 @@ def EmitC_IfOp : EmitC_Op<"if",
let hasCustomAssemblyFormat = 1;
}
-def EmitC_SubscriptOp : EmitC_Op<"subscript",
- [TypesMatchWith<"result type matches element type of 'array'",
- "array", "result",
- "::llvm::cast<ArrayType>($_self).getElementType()">]> {
- let summary = "Array subscript operation";
+def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
+ let summary = "Subscript operation";
let description = [{
With the `subscript` operation the subscript operator `[]` can be applied
- to variables or arguments of array type.
+ to variables or arguments of array, pointer and opaque type.
Example:
```mlir
%i = index.constant 1
%j = index.constant 7
- %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, index, index
+ %0 = emitc.subscript %arg0[%i, %j] : !emitc.array<4x8xf32>, index, index
+ %1 = emitc.subscript %arg1[%i] : !emitc.ptr<i32>, index
```
}];
- let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
- Variadic<IntegerIndexOrOpaqueType>:$indices);
+ let arguments = (ins Arg<AnyTypeOf<[
+ EmitC_ArrayType,
+ EmitC_OpaqueType,
+ EmitC_PointerType]>,
+ "the value to subscript">:$value,
+ Variadic<AnyType>:$indices);
let results = (outs AnyType:$result);
let builders = [
- OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
- build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
+ OpBuilder<(ins "TypedValue<ArrayType>":$array, "ValueRange":$indices), [{
+ build($_builder, $_state, array.getType().getElementType(), array, indices);
+ }]>,
+ OpBuilder<(ins "TypedValue<PointerType>":$pointer, "Value":$index), [{
+ build($_builder, $_state, pointer.getType().getPointee(), pointer,
+ ValueRange{index});
}]>
];
let hasVerifier = 1;
- let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array) `,` type($indices)";
+ let assemblyFormat = "$value `[` $indices `]` attr-dict `:` functional-type(operands, results)";
}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 0e3b646..25fa158 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -62,8 +62,14 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
}
+ auto arrayValue =
+ dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
+ if (!arrayValue) {
+ return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
+ }
+
auto subscript = rewriter.create<emitc::SubscriptOp>(
- op.getLoc(), operands.getMemref(), operands.getIndices());
+ op.getLoc(), arrayValue, operands.getIndices());
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
auto var =
@@ -81,9 +87,14 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
LogicalResult
matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
+ auto arrayValue =
+ dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
+ if (!arrayValue) {
+ return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
+ }
auto subscript = rewriter.create<emitc::SubscriptOp>(
- op.getLoc(), operands.getMemref(), operands.getIndices());
+ op.getLoc(), arrayValue, operands.getIndices());
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
operands.getValue());
return success();
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index f4a9dc3..7cbf28b 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -70,6 +70,11 @@ bool mlir::emitc::isSupportedIntegerType(Type type) {
return false;
}
+bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
+ return llvm::isa<IndexType, emitc::OpaqueType>(type) ||
+ isSupportedIntegerType(type);
+}
+
bool mlir::emitc::isSupportedFloatType(Type type) {
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
switch (floatType.getWidth()) {
@@ -780,12 +785,61 @@ LogicalResult emitc::YieldOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult emitc::SubscriptOp::verify() {
- if (getIndices().size() != (size_t)getArray().getType().getRank()) {
- return emitOpError() << "requires number of indices ("
- << getIndices().size()
- << ") to match the rank of the array type ("
- << getArray().getType().getRank() << ")";
+ // Checks for array operand.
+ if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().getType())) {
+ // Check number of indices.
+ if (getIndices().size() != (size_t)arrayType.getRank()) {
+ return emitOpError() << "on array operand requires number of indices ("
+ << getIndices().size()
+ << ") to match the rank of the array type ("
+ << arrayType.getRank() << ")";
+ }
+ // Check types of index operands.
+ for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
+ Type type = getIndices()[i].getType();
+ if (!isIntegerIndexOrOpaqueType(type)) {
+ return emitOpError() << "on array operand requires index operand " << i
+ << " to be integer-like, but got " << type;
+ }
+ }
+ // Check element type.
+ Type elementType = arrayType.getElementType();
+ if (elementType != getType()) {
+ return emitOpError() << "on array operand requires element type ("
+ << elementType << ") and result type (" << getType()
+ << ") to match";
+ }
+ return success();
}
+
+ // Checks for pointer operand.
+ if (auto pointerType =
+ llvm::dyn_cast<emitc::PointerType>(getValue().getType())) {
+ // Check number of indices.
+ if (getIndices().size() != 1) {
+ return emitOpError()
+ << "on pointer operand requires one index operand, but got "
+ << getIndices().size();
+ }
+ // Check types of index operand.
+ Type type = getIndices()[0].getType();
+ if (!isIntegerIndexOrOpaqueType(type)) {
+ return emitOpError() << "on pointer operand requires index operand to be "
+ "integer-like, but got "
+ << type;
+ }
+ // Check pointee type.
+ Type pointeeType = pointerType.getPointee();
+ if (pointeeType != getType()) {
+ return emitOpError() << "on pointer operand requires pointee type ("
+ << pointeeType << ") and result type (" << getType()
+ << ") to match";
+ }
+ return success();
+ }
+
+ // The operand has opaque type, so we can't assume anything about the number
+ // or types of index operands.
return success();
}
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 0b07b4b..ee87c1d 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1104,7 +1104,7 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
std::string out;
llvm::raw_string_ostream ss(out);
- ss << getOrCreateName(op.getArray());
+ ss << getOrCreateName(op.getValue());
for (auto index : op.getIndices()) {
ss << "[" << getOrCreateName(index) << "]";
}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index 9793b2d..7aa2ba8 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -6,7 +6,7 @@ func.func @memref_store(%v : f32, %i: index, %j: index) {
// CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
%0 = memref.alloca() : memref<4x8xf32>
- // CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
+ // CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
// CHECK: emitc.assign %[[v]] : f32 to %[[SUBSCRIPT:.*]] : f32
memref.store %v, %0[%i, %j] : memref<4x8xf32>
return
@@ -19,7 +19,7 @@ func.func @memref_load(%i: index, %j: index) -> f32 {
// CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
%0 = memref.alloca() : memref<4x8xf32>
- // CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
+ // CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
// CHECK: %[[VAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
// CHECK: emitc.assign %[[LOAD]] : f32 to %[[VAR]] : f32
%1 = memref.load %0[%i, %j] : memref<4x8xf32>
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 22423cf..bbaab0d 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -390,8 +390,48 @@ func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) {
// -----
-func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
- // expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
- %0 = emitc.subscript %arg0[%arg2] : <4x8xf32>, index
+func.func @test_subscript_array_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index) {
+ // expected-error @+1 {{'emitc.subscript' op on array operand requires number of indices (1) to match the rank of the array type (2)}}
+ %0 = emitc.subscript %arg0[%arg1] : (!emitc.array<4x8xf32>, index) -> f32
+ return
+}
+
+// -----
+
+func.func @test_subscript_array_index_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index, %arg2: f32) {
+ // expected-error @+1 {{'emitc.subscript' op on array operand requires index operand 1 to be integer-like, but got 'f32'}}
+ %0 = emitc.subscript %arg0[%arg1, %arg2] : (!emitc.array<4x8xf32>, index, f32) -> f32
+ return
+}
+
+// -----
+
+func.func @test_subscript_array_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index, %arg2: index) {
+ // expected-error @+1 {{'emitc.subscript' op on array operand requires element type ('f32') and result type ('i32') to match}}
+ %0 = emitc.subscript %arg0[%arg1, %arg2] : (!emitc.array<4x8xf32>, index, index) -> i32
+ return
+}
+
+// -----
+
+func.func @test_subscript_ptr_indices_mismatch(%arg0: !emitc.ptr<f32>, %arg1: index) {
+ // expected-error @+1 {{'emitc.subscript' op on pointer operand requires one index operand, but got 2}}
+ %0 = emitc.subscript %arg0[%arg1, %arg1] : (!emitc.ptr<f32>, index, index) -> f32
+ return
+}
+
+// -----
+
+func.func @test_subscript_ptr_index_type_mismatch(%arg0: !emitc.ptr<f32>, %arg1: f64) {
+ // expected-error @+1 {{'emitc.subscript' op on pointer operand requires index operand to be integer-like, but got 'f64'}}
+ %0 = emitc.subscript %arg0[%arg1] : (!emitc.ptr<f32>, f64) -> f32
+ return
+}
+
+// -----
+
+func.func @test_subscript_ptr_type_mismatch(%arg0: !emitc.ptr<f32>, %arg1: index) {
+ // expected-error @+1 {{'emitc.subscript' op on pointer operand requires pointee type ('f32') and result type ('f64') to match}}
+ %0 = emitc.subscript %arg0[%arg1] : (!emitc.ptr<f32>, index) -> f64
return
}
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 5f00a29..ace3670 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -214,6 +214,13 @@ func.func @test_for_not_index_induction(%arg0 : i16, %arg1 : i16, %arg2 : i16) {
return
}
+func.func @test_subscript(%arg0 : !emitc.array<2x3xf32>, %arg1 : !emitc.ptr<i32>, %arg2 : !emitc.opaque<"std::map<char, int>">, %idx0 : index, %idx1 : i32, %idx2 : !emitc.opaque<"char">) {
+ %0 = emitc.subscript %arg0[%idx0, %idx1] : (!emitc.array<2x3xf32>, index, i32) -> f32
+ %1 = emitc.subscript %arg1[%idx0] : (!emitc.ptr<i32>, index) -> i32
+ %2 = emitc.subscript %arg2[%idx2] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
+ return
+}
+
emitc.verbatim "#ifdef __cplusplus"
emitc.verbatim "extern \"C\" {"
emitc.verbatim "#endif // __cplusplus"
diff --git a/mlir/test/Target/Cpp/subscript.mlir b/mlir/test/Target/Cpp/subscript.mlir
index a6c82df..0b38895 100644
--- a/mlir/test/Target/Cpp/subscript.mlir
+++ b/mlir/test/Target/Cpp/subscript.mlir
@@ -1,24 +1,44 @@
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
-func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
- %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>, index, index
- %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>, index, index
+func.func @load_store_array(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
+ %0 = emitc.subscript %arg0[%arg2, %arg3] : (!emitc.array<4x8xf32>, index, index) -> f32
+ %1 = emitc.subscript %arg1[%arg2, %arg3] : (!emitc.array<3x5xf32>, index, index) -> f32
emitc.assign %0 : f32 to %1 : f32
return
}
-// CHECK: void load_store(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
+// CHECK: void load_store_array(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
// CHECK-SAME: size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
// CHECK-NEXT: [[ARR2]][[[I]]][[[J]]] = [[ARR1]][[[I]]][[[J]]];
+func.func @load_store_pointer(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>, %arg2: index, %arg3: index) {
+ %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, index) -> f32
+ %1 = emitc.subscript %arg1[%arg3] : (!emitc.ptr<f32>, index) -> f32
+ emitc.assign %0 : f32 to %1 : f32
+ return
+}
+// CHECK: void load_store_pointer(float* [[PTR1:[^ ]*]], float* [[PTR2:[^ ]*]],
+// CHECK-SAME: size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
+// CHECK-NEXT: [[PTR2]][[[J]]] = [[PTR1]][[[I]]];
+
+func.func @load_store_opaque(%arg0: !emitc.opaque<"std::map<char, int>">, %arg1: !emitc.opaque<"std::map<char, int>">, %arg2: !emitc.opaque<"char">, %arg3: !emitc.opaque<"char">) {
+ %0 = emitc.subscript %arg0[%arg2] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
+ %1 = emitc.subscript %arg1[%arg3] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
+ emitc.assign %0 : !emitc.opaque<"int"> to %1 : !emitc.opaque<"int">
+ return
+}
+// CHECK: void load_store_opaque(std::map<char, int> [[MAP1:[^ ]*]], std::map<char, int> [[MAP2:[^ ]*]],
+// CHECK-SAME: char [[I:[^ ]*]], char [[J:[^ ]*]])
+// CHECK-NEXT: [[MAP2]][[[J]]] = [[MAP1]][[[I]]];
+
emitc.func @func1(%arg0 : f32) {
emitc.return
}
emitc.func @call_arg(%arg0: !emitc.array<4x8xf32>, %i: i32, %j: i16,
%k: i8) {
- %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, i32, i16
- %1 = emitc.subscript %arg0[%j, %k] : <4x8xf32>, i16, i8
+ %0 = emitc.subscript %arg0[%i, %j] : (!emitc.array<4x8xf32>, i32, i16) -> f32
+ %1 = emitc.subscript %arg0[%j, %k] : (!emitc.array<4x8xf32>, i16, i8) -> f32
emitc.call @func1 (%0) : (f32) -> ()
emitc.call_opaque "func2" (%1) : (f32) -> ()