diff options
author | Alex Zinenko <zinenko@google.com> | 2020-08-19 18:38:56 +0200 |
---|---|---|
committer | Alex Zinenko <zinenko@google.com> | 2020-08-19 18:50:19 +0200 |
commit | da562974628017ae92c451ca064fea5b59ad71a4 (patch) | |
tree | e76a7d1cb499a2eb6d7de745fa6bd302133701ef /mlir/test/CAPI | |
parent | 0f95e73190c9a555c2917a2963eab128c4ba5395 (diff) | |
download | llvm-da562974628017ae92c451ca064fea5b59ad71a4.zip llvm-da562974628017ae92c451ca064fea5b59ad71a4.tar.gz llvm-da562974628017ae92c451ca064fea5b59ad71a4.tar.bz2 |
[mlir] expose standard attributes to C API
Provide C API for MLIR standard attributes. Since standard attributes live
under lib/IR in core MLIR, place the C APIs in the IR library as well (standard
ops will go in a separate library).
Affine map and integer set attributes are only exposed as placeholder types
with IsA support due to the lack of C APIs for the corresponding types.
Integer and floating point attribute APIs expecting APInt and APFloat are not
exposed pending decision on how to support APInt and APFloat.
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D86143
Diffstat (limited to 'mlir/test/CAPI')
-rw-r--r-- | mlir/test/CAPI/ir.c | 242 |
1 files changed, 241 insertions, 1 deletions
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index 12dc100..0a8ebae 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -12,11 +12,14 @@ #include "mlir-c/IR.h" #include "mlir-c/Registration.h" +#include "mlir-c/StandardAttributes.h" #include "mlir-c/StandardTypes.h" #include <assert.h> +#include <math.h> #include <stdio.h> #include <stdlib.h> +#include <string.h> void populateLoopBody(MlirContext ctx, MlirBlock loopBody, MlirLocation location, MlirBlock funcBody) { @@ -380,6 +383,210 @@ static int printStandardTypes(MlirContext ctx) { return 0; } +void callbackSetFixedLengthString(const char *data, intptr_t len, + void *userData) { + strncpy(userData, data, len); +} + +int printStandardAttributes(MlirContext ctx) { + MlirAttribute floating = + mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), 2.0); + if (!mlirAttributeIsAFloat(floating) || + fabs(mlirFloatAttrGetValueDouble(floating) - 2.0) > 1E-6) + return 1; + mlirAttributeDump(floating); + + MlirAttribute integer = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 42); + if (!mlirAttributeIsAInteger(integer) || + mlirIntegerAttrGetValueInt(integer) != 42) + return 2; + mlirAttributeDump(integer); + + MlirAttribute boolean = mlirBoolAttrGet(ctx, 1); + if (!mlirAttributeIsABool(boolean) || !mlirBoolAttrGetValue(boolean)) + return 3; + mlirAttributeDump(boolean); + + const char data[] = "abcdefghijklmnopqestuvwxyz"; + char buffer[10]; + MlirAttribute opaque = + mlirOpaqueAttrGet(ctx, "std", 3, data, mlirNoneTypeGet(ctx)); + if (!mlirAttributeIsAOpaque(opaque) || + strcmp("std", mlirOpaqueAttrGetDialectNamespace(opaque))) + return 4; + mlirOpaqueAttrGetData(opaque, callbackSetFixedLengthString, buffer); + if (buffer[0] != 'a' || buffer[1] != 'b' || buffer[2] != 'c') + return 5; + mlirAttributeDump(opaque); + + MlirAttribute string = mlirStringAttrGet(ctx, 2, data + 3); + if (!mlirAttributeIsAString(string)) + return 6; + mlirStringAttrGetValue(string, callbackSetFixedLengthString, buffer); + if (buffer[0] != 'd' || buffer[1] != 'e') + return 7; + mlirAttributeDump(string); + + MlirAttribute flatSymbolRef = mlirFlatSymbolRefAttrGet(ctx, 3, data + 5); + if (!mlirAttributeIsAFlatSymbolRef(flatSymbolRef)) + return 8; + mlirFloatSymbolRefAttrGetValue(flatSymbolRef, callbackSetFixedLengthString, + buffer); + if (buffer[0] != 'f' || buffer[1] != 'g' || buffer[2] != 'h') + return 9; + mlirAttributeDump(flatSymbolRef); + + MlirAttribute symbols[] = {flatSymbolRef, flatSymbolRef}; + MlirAttribute symbolRef = mlirSymbolRefAttrGet(ctx, 2, data + 8, 2, symbols); + if (!mlirAttributeIsASymbolRef(symbolRef) || + mlirSymbolRefAttrGetNumNestedReferences(symbolRef) != 2 || + !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 0), + flatSymbolRef) || + !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 1), + flatSymbolRef)) + return 10; + mlirSymbolRefAttrGetLeafReference(symbolRef, callbackSetFixedLengthString, + buffer); + mlirSymbolRefAttrGetRootReference(symbolRef, callbackSetFixedLengthString, + buffer + 3); + if (buffer[0] != 'f' || buffer[1] != 'g' || buffer[2] != 'h' || + buffer[3] != 'i' || buffer[4] != 'j') + return 11; + mlirAttributeDump(symbolRef); + + MlirAttribute type = mlirTypeAttrGet(mlirF32TypeGet(ctx)); + if (!mlirAttributeIsAType(type) || + !mlirTypeEqual(mlirF32TypeGet(ctx), mlirTypeAttrGetValue(type))) + return 12; + mlirAttributeDump(type); + + MlirAttribute unit = mlirUnitAttrGet(ctx); + if (!mlirAttributeIsAUnit(unit)) + return 13; + mlirAttributeDump(unit); + + int64_t shape[] = {1, 2}; + + int bools[] = {0, 1}; + uint32_t uints32[] = {0u, 1u}; + int32_t ints32[] = {0, 1}; + uint64_t uints64[] = {0u, 1u}; + int64_t ints64[] = {0, 1}; + float floats[] = {0.0f, 1.0f}; + double doubles[] = {0.0, 1.0}; + MlirAttribute boolElements = mlirDenseElementsAttrBoolGet( + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1)), 2, bools); + MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get( + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32)), 2, + uints32); + MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get( + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 2, + ints32); + MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get( + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64)), 2, + uints64); + MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get( + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 2, + ints64); + MlirAttribute floatElements = mlirDenseElementsAttrFloatGet( + mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), 2, floats); + MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet( + mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx)), 2, doubles); + + if (!mlirAttributeIsADenseElements(boolElements) || + !mlirAttributeIsADenseElements(uint32Elements) || + !mlirAttributeIsADenseElements(int32Elements) || + !mlirAttributeIsADenseElements(uint64Elements) || + !mlirAttributeIsADenseElements(int64Elements) || + !mlirAttributeIsADenseElements(floatElements) || + !mlirAttributeIsADenseElements(doubleElements)) + return 14; + + if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 || + mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 || + mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 || + mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 || + mlirDenseElementsAttrGetInt64Value(int64Elements, 1) != 1 || + fabsf(mlirDenseElementsAttrGetFloatValue(floatElements, 1) - 1.0f) > + 1E-6f || + fabs(mlirDenseElementsAttrGetDoubleValue(doubleElements, 1) - 1.0) > 1E-6) + return 15; + + mlirAttributeDump(boolElements); + mlirAttributeDump(uint32Elements); + mlirAttributeDump(int32Elements); + mlirAttributeDump(uint64Elements); + mlirAttributeDump(int64Elements); + mlirAttributeDump(floatElements); + mlirAttributeDump(doubleElements); + + MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet( + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1)), 1); + MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet( + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 1); + MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet( + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 1); + MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet( + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 1); + MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet( + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 1); + MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet( + mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), 1.0f); + MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet( + mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx)), 1.0); + + if (!mlirAttributeIsADenseElements(splatBool) || + !mlirDenseElementsAttrIsSplat(splatBool) || + !mlirAttributeIsADenseElements(splatUInt32) || + !mlirDenseElementsAttrIsSplat(splatUInt32) || + !mlirAttributeIsADenseElements(splatInt32) || + !mlirDenseElementsAttrIsSplat(splatInt32) || + !mlirAttributeIsADenseElements(splatUInt64) || + !mlirDenseElementsAttrIsSplat(splatUInt64) || + !mlirAttributeIsADenseElements(splatInt64) || + !mlirDenseElementsAttrIsSplat(splatInt64) || + !mlirAttributeIsADenseElements(splatFloat) || + !mlirDenseElementsAttrIsSplat(splatFloat) || + !mlirAttributeIsADenseElements(splatDouble) || + !mlirDenseElementsAttrIsSplat(splatDouble)) + return 16; + + if (mlirDenseElementsAttrGetBoolSplatValue(splatBool) != 1 || + mlirDenseElementsAttrGetUInt32SplatValue(splatUInt32) != 1 || + mlirDenseElementsAttrGetInt32SplatValue(splatInt32) != 1 || + mlirDenseElementsAttrGetUInt64SplatValue(splatUInt64) != 1 || + mlirDenseElementsAttrGetInt64SplatValue(splatInt64) != 1 || + fabsf(mlirDenseElementsAttrGetFloatSplatValue(splatFloat) - 1.0f) > + 1E-6f || + fabs(mlirDenseElementsAttrGetDoubleSplatValue(splatDouble) - 1.0) > 1E-6) + return 17; + + mlirAttributeDump(splatBool); + mlirAttributeDump(splatUInt32); + mlirAttributeDump(splatInt32); + mlirAttributeDump(splatUInt64); + mlirAttributeDump(splatInt64); + mlirAttributeDump(splatFloat); + mlirAttributeDump(splatDouble); + + mlirAttributeDump(mlirElementsAttrGetValue(floatElements, 2, uints64)); + mlirAttributeDump(mlirElementsAttrGetValue(doubleElements, 2, uints64)); + + int64_t indices[] = {4, 7}; + int64_t two = 2; + MlirAttribute indicesAttr = mlirDenseElementsAttrInt64Get( + mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64)), 2, + indices); + MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet( + mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx)), 2, floats); + MlirAttribute sparseAttr = mlirSparseElementsAttribute( + mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), indicesAttr, + valuesAttr); + mlirAttributeDump(sparseAttr); + + return 0; +} + int main() { MlirContext ctx = mlirContextCreate(); mlirRegisterAllDialects(ctx); @@ -454,10 +661,43 @@ int main() { // CHECK: tuple<memref<*xf32, 4>, f32> // CHECK: 0 // clang-format on - fprintf(stderr, "@types"); + fprintf(stderr, "@types\n"); int errcode = printStandardTypes(ctx); fprintf(stderr, "%d\n", errcode); + // clang-format off + // CHECK-LABEL: @attrs + // CHECK: 2.000000e+00 : f64 + // CHECK: 42 : i32 + // CHECK: true + // CHECK: #std.abc + // CHECK: "de" + // CHECK: @fgh + // CHECK: @ij::@fgh::@fgh + // CHECK: f32 + // CHECK: unit + // CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1> + // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui32> + // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi32> + // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui64> + // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi64> + // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32> + // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64> + // CHECK: dense<true> : tensor<1x2xi1> + // CHECK: dense<1> : tensor<1x2xi32> + // CHECK: dense<1> : tensor<1x2xi32> + // CHECK: dense<1> : tensor<1x2xi64> + // CHECK: dense<1> : tensor<1x2xi64> + // CHECK: dense<1.000000e+00> : tensor<1x2xf32> + // CHECK: dense<1.000000e+00> : tensor<1x2xf64> + // CHECK: 1.000000e+00 : f32 + // CHECK: 1.000000e+00 : f64 + // CHECK: sparse<[4, 7], [0.000000e+00, 1.000000e+00]> : tensor<1x2xf32> + // clang-format on + fprintf(stderr, "@attrs\n"); + errcode = printStandardAttributes(ctx); + fprintf(stderr, "%d\n", errcode); + mlirContextDestroy(ctx); return 0; |