diff options
| -rw-r--r-- | mlir/python/mlir/runtime/np_to_memref.py | 12 | ||||
| -rw-r--r-- | mlir/test/python/execution_engine.py | 41 |
2 files changed, 52 insertions, 1 deletions
diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index 882b275..8cca1e7 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -37,6 +37,11 @@ class BF16(ctypes.Structure): _fields_ = [("bf16", ctypes.c_int16)] +class F8E5M2(ctypes.Structure): + """A ctype representation for MLIR's Float8E5M2.""" + + _fields_ = [("f8E5M2", ctypes.c_int8)] + # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype def as_ctype(dtp): @@ -49,6 +54,8 @@ def as_ctype(dtp): return F16 if ml_dtypes is not None and dtp == ml_dtypes.bfloat16: return BF16 + if ml_dtypes is not None and dtp == ml_dtypes.float8_e5m2: + return F8E5M2 return np.ctypeslib.as_ctypes_type(dtp) @@ -65,6 +72,11 @@ def to_numpy(array): ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" if array.dtype == BF16: return array.view("bfloat16") + assert not ( + array.dtype == F8E5M2 and ml_dtypes is None + ), f"float8_e5m2 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" + if array.dtype == F8E5M2: + return array.view("float8_e5m2") return array diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py index 8125bf3..1cdda63 100644 --- a/mlir/test/python/execution_engine.py +++ b/mlir/test/python/execution_engine.py @@ -5,7 +5,7 @@ from mlir.ir import * from mlir.passmanager import * from mlir.execution_engine import * from mlir.runtime import * -from ml_dtypes import bfloat16 +from ml_dtypes import bfloat16, float8_e5m2 # Log everything to stderr and flush so that we have a unified stream to match @@ -561,6 +561,45 @@ def testBF16Memref(): run(testBF16Memref) +# Test f8E5M2 memrefs +# CHECK-LABEL: TEST: testF8E5M2Memref +def testF8E5M2Memref(): + with Context(): + module = Module.parse( + """ + module { + func.func @main(%arg0: memref<1xf8E5M2>, + %arg1: memref<1xf8E5M2>) attributes { llvm.emit_c_interface } { + %0 = arith.constant 0 : index + %1 = memref.load %arg0[%0] : memref<1xf8E5M2> + memref.store %1, %arg1[%0] : memref<1xf8E5M2> + return + } + } """ + ) + + arg1 = np.array([0.5]).astype(float8_e5m2) + arg2 = np.array([0.0]).astype(float8_e5m2) + + arg1_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg1)) + ) + arg2_memref_ptr = ctypes.pointer( + ctypes.pointer(get_ranked_memref_descriptor(arg2)) + ) + + execution_engine = ExecutionEngine(lowerToLLVM(module)) + execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr) + + # test to-numpy utility + # CHECK: [0.5] + npout = ranked_memref_to_numpy(arg2_memref_ptr[0]) + log(npout) + + +run(testF8E5M2Memref) + + # Test addition of two 2d_memref # CHECK-LABEL: TEST: testDynamicMemrefAdd2D def testDynamicMemrefAdd2D(): |
