aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMehdi Amini <joker.eph@gmail.com>2024-03-05 18:57:45 -0800
committerGitHub <noreply@github.com>2024-03-05 18:57:45 -0800
commit96fc54828a0e72e60f90d237e571b47cad5bab87 (patch)
tree2beb21021ef4eed0e760e655c133619d5c9f3e0c
parent85388a06b6022d0a7bc984bcaff86cf96f045338 (diff)
downloadllvm-96fc54828a0e72e60f90d237e571b47cad5bab87.zip
llvm-96fc54828a0e72e60f90d237e571b47cad5bab87.tar.gz
llvm-96fc54828a0e72e60f90d237e571b47cad5bab87.tar.bz2
Revert "[mlir][py] better support for arith.constant construction" (#84103)
Reverts llvm/llvm-project#83259 This broke an integration test on Windows
-rw-r--r--mlir/python/mlir/dialects/arith.py23
-rw-r--r--mlir/test/python/dialects/arith_dialect.py38
2 files changed, 2 insertions, 59 deletions
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 83a50c7..61c6917 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -5,8 +5,6 @@
from ._arith_ops_gen import *
from ._arith_ops_gen import _Dialect
from ._arith_enum_gen import *
-from array import array as _array
-from typing import overload
try:
from ..ir import *
@@ -45,30 +43,13 @@ def _is_float_type(type: Type):
class ConstantOp(ConstantOp):
"""Specialization for the constant op class."""
- @overload
- def __init__(self, value: Attribute, *, loc=None, ip=None):
- ...
-
- @overload
def __init__(
- self, result: Type, value: Union[int, float, _array], *, loc=None, ip=None
+ self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
):
- ...
-
- def __init__(self, result, value, *, loc=None, ip=None):
- if value is None:
- assert isinstance(result, Attribute)
- super().__init__(result, loc=loc, ip=ip)
- return
-
if isinstance(value, int):
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
elif isinstance(value, float):
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
- elif isinstance(value, _array) and value.typecode in ["i", "l"]:
- super().__init__(DenseIntElementsAttr.get(value, type=result))
- elif isinstance(value, _array) and value.typecode in ["f", "d"]:
- super().__init__(DenseFPElementsAttr.get(value, type=result))
else:
super().__init__(value, loc=loc, ip=ip)
@@ -98,6 +79,6 @@ class ConstantOp(ConstantOp):
def constant(
- result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
+ result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
) -> Value:
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index ef0e162..8bb80ee 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -4,7 +4,6 @@ from functools import partialmethod
from mlir.ir import *
import mlir.dialects.arith as arith
import mlir.dialects.func as func
-from array import array
def run(f):
@@ -93,40 +92,3 @@ def testArithValue():
b = a * a
# CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
print(b)
-
-
-# CHECK-LABEL: TEST: testArrayConstantConstruction
-@run
-def testArrayConstantConstruction():
- with Context(), Location.unknown():
- module = Module.create()
- with InsertionPoint(module.body):
- i32_array = array("i", [1, 2, 3, 4])
- i32 = IntegerType.get_signless(32)
- vec_i32 = VectorType.get([2, 2], i32)
- arith.constant(vec_i32, i32_array)
- arith.ConstantOp(vec_i32, DenseIntElementsAttr.get(i32_array, type=vec_i32))
-
- i64_array = array("l", [5, 6, 7, 8])
- i64 = IntegerType.get_signless(64)
- vec_i64 = VectorType.get([1, 4], i64)
- arith.constant(vec_i64, i64_array)
- arith.ConstantOp(vec_i64, DenseIntElementsAttr.get(i64_array, type=vec_i64))
-
- f32_array = array("f", [1.0, 2.0, 3.0, 4.0])
- f32 = F32Type.get()
- vec_f32 = VectorType.get([4, 1], f32)
- arith.constant(vec_f32, f32_array)
- arith.ConstantOp(vec_f32, DenseFPElementsAttr.get(f32_array, type=vec_f32))
-
- f64_array = array("d", [1.0, 2.0, 3.0, 4.0])
- f64 = F64Type.get()
- vec_f64 = VectorType.get([2, 1, 2], f64)
- arith.constant(vec_f64, f64_array)
- arith.ConstantOp(vec_f64, DenseFPElementsAttr.get(f64_array, type=vec_f64))
-
- # CHECK-COUNT-2: arith.constant dense<[{{\[}}1, 2], [3, 4]]> : vector<2x2xi32>
- # CHECK-COUNT-2: arith.constant dense<[{{\[}}5, 6, 7, 8]]> : vector<1x4xi64>
- # CHECK-COUNT-2: arith.constant dense<[{{\[}}1.000000e+00], [2.000000e+00], [3.000000e+00], [4.000000e+00]]> : vector<4x1xf32>
- # CHECK-COUNT-2: arith.constant dense<[{{\[}}[1.000000e+00, 2.000000e+00]], [{{\[}}3.000000e+00, 4.000000e+00]]]> : vector<2x1x2xf64>
- print(module)