aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r--mlir/lib/Conversion/CMakeLists.txt1
-rw-r--r--mlir/lib/Conversion/MathToXeVM/CMakeLists.txt22
-rw-r--r--mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp167
-rw-r--r--mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp5
-rw-r--r--mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp288
-rw-r--r--mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp4
-rw-r--r--mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp7
-rw-r--r--mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp4
-rw-r--r--mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp5
9 files changed, 403 insertions, 100 deletions
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 71986f8..bebf1b8 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -40,6 +40,7 @@ add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
add_subdirectory(MathToROCDL)
add_subdirectory(MathToSPIRV)
+add_subdirectory(MathToXeVM)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
add_subdirectory(MemRefToSPIRV)
diff --git a/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
new file mode 100644
index 0000000..050c0ed
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/CMakeLists.txt
@@ -0,0 +1,22 @@
+add_mlir_conversion_library(MLIRMathToXeVM
+ MathToXeVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToXeVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRArithAttrToLLVMConversion
+ MLIRArithDialect
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRMathDialect
+ MLIRXeVMDialect
+ MLIRPass
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
new file mode 100644
index 0000000..0fe31d0
--- /dev/null
+++ b/mlir/lib/Conversion/MathToXeVM/MathToXeVM.cpp
@@ -0,0 +1,167 @@
+//===-- MathToXeVM.cpp - conversion from Math to XeVM ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
+#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/FormatVariadic.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOXEVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-xevm"
+
+/// Convert math ops marked with `fast` (`afn`) to native OpenCL intrinsics.
+template <typename Op>
+struct ConvertNativeFuncPattern final : public OpConversionPattern<Op> {
+
+ ConvertNativeFuncPattern(MLIRContext *context, StringRef nativeFunc,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern<Op>(context, benefit), nativeFunc(nativeFunc) {}
+
+ LogicalResult
+ matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!isSPIRVCompatibleFloatOrVec(op.getType()))
+ return failure();
+
+ arith::FastMathFlags fastFlags = op.getFastmath();
+ if (!arith::bitEnumContainsAll(fastFlags, arith::FastMathFlags::afn))
+ return rewriter.notifyMatchFailure(op, "not a fastmath `afn` operation");
+
+ SmallVector<Type, 1> operandTypes;
+ for (auto operand : adaptor.getOperands()) {
+ Type opTy = operand.getType();
+ // This pass only supports operations on vectors that are already in SPIRV
+ // supported vector sizes: Distributing unsupported vector sizes to SPIRV
+ // supported vector sizes are done in other blocking optimization passes.
+ if (!isSPIRVCompatibleFloatOrVec(opTy))
+ return rewriter.notifyMatchFailure(
+ op, llvm::formatv("incompatible operand type: '{0}'", opTy));
+ operandTypes.push_back(opTy);
+ }
+
+ auto moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
+ auto funcOpRes = LLVM::lookupOrCreateFn(
+ rewriter, moduleOp, getMangledNativeFuncName(operandTypes),
+ operandTypes, op.getType());
+ assert(!failed(funcOpRes));
+ LLVM::LLVMFuncOp funcOp = funcOpRes.value();
+
+ auto callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
+ op, funcOp, adaptor.getOperands());
+ // Preserve fastmath flags in our MLIR op when converting to llvm function
+ // calls, in order to allow further fastmath optimizations: We thus need to
+ // convert arith fastmath attrs into attrs recognized by llvm.
+ arith::AttrConvertFastMathToLLVM<Op, LLVM::CallOp> fastAttrConverter(op);
+ mlir::NamedAttribute fastAttr = fastAttrConverter.getAttrs()[0];
+ callOp->setAttr(fastAttr.getName(), fastAttr.getValue());
+ return success();
+ }
+
+ inline bool isSPIRVCompatibleFloatOrVec(Type type) const {
+ if (type.isFloat())
+ return true;
+ if (auto vecType = dyn_cast<VectorType>(type)) {
+ if (!vecType.getElementType().isFloat())
+ return false;
+ // SPIRV distinguishes between vectors and matrices: OpenCL native math
+ // intrsinics are not compatible with matrices.
+ ArrayRef<int64_t> shape = vecType.getShape();
+ if (shape.size() != 1)
+ return false;
+ // SPIRV only allows vectors of size 2, 3, 4, 8, 16.
+ if (shape[0] == 2 || shape[0] == 3 || shape[0] == 4 || shape[0] == 8 ||
+ shape[0] == 16)
+ return true;
+ }
+ return false;
+ }
+
+ inline std::string
+ getMangledNativeFuncName(const ArrayRef<Type> operandTypes) const {
+ std::string mangledFuncName =
+ "_Z" + std::to_string(nativeFunc.size()) + nativeFunc.str();
+
+ auto appendFloatToMangledFunc = [&mangledFuncName](Type type) {
+ if (type.isF32())
+ mangledFuncName += "f";
+ else if (type.isF16())
+ mangledFuncName += "Dh";
+ else if (type.isF64())
+ mangledFuncName += "d";
+ };
+
+ for (auto type : operandTypes) {
+ if (auto vecType = dyn_cast<VectorType>(type)) {
+ mangledFuncName += "Dv" + std::to_string(vecType.getShape()[0]) + "_";
+ appendFloatToMangledFunc(vecType.getElementType());
+ } else
+ appendFloatToMangledFunc(type);
+ }
+
+ return mangledFuncName;
+ }
+
+ const StringRef nativeFunc;
+};
+
+void mlir::populateMathToXeVMConversionPatterns(RewritePatternSet &patterns,
+ bool convertArith) {
+ patterns.add<ConvertNativeFuncPattern<math::ExpOp>>(patterns.getContext(),
+ "__spirv_ocl_native_exp");
+ patterns.add<ConvertNativeFuncPattern<math::CosOp>>(patterns.getContext(),
+ "__spirv_ocl_native_cos");
+ patterns.add<ConvertNativeFuncPattern<math::Exp2Op>>(
+ patterns.getContext(), "__spirv_ocl_native_exp2");
+ patterns.add<ConvertNativeFuncPattern<math::LogOp>>(patterns.getContext(),
+ "__spirv_ocl_native_log");
+ patterns.add<ConvertNativeFuncPattern<math::Log2Op>>(
+ patterns.getContext(), "__spirv_ocl_native_log2");
+ patterns.add<ConvertNativeFuncPattern<math::Log10Op>>(
+ patterns.getContext(), "__spirv_ocl_native_log10");
+ patterns.add<ConvertNativeFuncPattern<math::PowFOp>>(
+ patterns.getContext(), "__spirv_ocl_native_powr");
+ patterns.add<ConvertNativeFuncPattern<math::RsqrtOp>>(
+ patterns.getContext(), "__spirv_ocl_native_rsqrt");
+ patterns.add<ConvertNativeFuncPattern<math::SinOp>>(patterns.getContext(),
+ "__spirv_ocl_native_sin");
+ patterns.add<ConvertNativeFuncPattern<math::SqrtOp>>(
+ patterns.getContext(), "__spirv_ocl_native_sqrt");
+ patterns.add<ConvertNativeFuncPattern<math::TanOp>>(patterns.getContext(),
+ "__spirv_ocl_native_tan");
+ if (convertArith)
+ patterns.add<ConvertNativeFuncPattern<arith::DivFOp>>(
+ patterns.getContext(), "__spirv_ocl_native_divide");
+}
+
+namespace {
+struct ConvertMathToXeVMPass
+ : public impl::ConvertMathToXeVMBase<ConvertMathToXeVMPass> {
+ using Base::Base;
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToXeVMPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ populateMathToXeVMConversionPatterns(patterns, convertArith);
+ ConversionTarget target(getContext());
+ target.addLegalDialect<BuiltinDialect, LLVM::LLVMDialect>();
+ if (failed(
+ applyPartialConversion(getOperation(), target, std::move(patterns))))
+ signalPassFailure();
+}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 2b7bdc9..11f866c 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -22,6 +22,7 @@
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
#include <cstdint>
#include <numeric>
@@ -110,9 +111,7 @@ static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
{TypeAttr::get(memrefType.getElementType())}));
IndexType indexType = builder.getIndexType();
- int64_t numElements = std::accumulate(memrefType.getShape().begin(),
- memrefType.getShape().end(), int64_t{1},
- std::multiplies<int64_t>());
+ int64_t numElements = llvm::product_of(memrefType.getShape());
emitc::ConstantOp numElementsValue = emitc::ConstantOp::create(
builder, loc, indexType, builder.getIndexAttr(numElements));
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index a5336ed..00df14b1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1392,6 +1392,137 @@ public:
}
};
+// Collapse tensor<1xiN> into tensor<iN>
+// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
+static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
+ Location loc) {
+ SmallVector<ReassociationExprs, 1> reassociation;
+ // Create the collapsed type
+ auto inputType = cast<RankedTensorType>(input.getType());
+ auto elemType = inputType.getElementType();
+ auto collapsedType = RankedTensorType::get({}, elemType);
+ // Emit the collapse op
+ return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
+ reassociation);
+}
+
+static llvm::SmallVector<int8_t>
+convertToI8(const llvm::SmallVector<int32_t> &input) {
+ llvm::SmallVector<int8_t> output;
+ output.reserve(input.size());
+
+ for (auto v : llvm::map_range(
+ input, [](int32_t val) { return static_cast<int8_t>(val); })) {
+ output.push_back(v);
+ }
+ return output;
+}
+
+// The shift or multiplier may be either constant or non-constant, depending on
+// whether dynamic extension is enabled.
+// - If the shift or multiplier is non-constant, add it as an input to
+// linalg::GenericOp by:
+// 1. Pushing it into 'genericInputs'.
+// 2. Appending a corresponding affine map to 'indexingMaps'.
+// - If the shift or multiplier is constant, set 'constant' instead.
+static void setupLinalgGenericOpInputAndIndexingMap(
+ PatternRewriter &rewriter, llvm::SmallVector<int32_t> &values,
+ SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+ bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg,
+ bool isShift = false) {
+
+ auto loc = op.getLoc();
+ auto inputTy = cast<ShapedType>(op.getInput().getType());
+ unsigned rank = inputTy.getRank();
+ SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)};
+
+ if (isConstant) {
+ // If we are rescaling per-channel then we need to store the
+ // values in a buffer.
+ if (values.size() == 1) {
+ IntegerAttr intAttr = isShift
+ ? rewriter.getI8IntegerAttr(values.front())
+ : rewriter.getI32IntegerAttr(values.front());
+ constant = rewriter.create<arith::ConstantOp>(loc, intAttr);
+ } else {
+ auto elementType =
+ isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
+ auto tensorType = RankedTensorType::get(
+ {static_cast<int64_t>(values.size())}, elementType);
+ DenseIntElementsAttr EltAttr;
+ if (isShift)
+ EltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values));
+ else
+ EltAttr = DenseIntElementsAttr::get(tensorType, values);
+ genericInputs.push_back(
+ arith::ConstantOp::create(rewriter, loc, EltAttr));
+ indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+ /*symbolCount=*/0, exprs,
+ rewriter.getContext()));
+ }
+ } else {
+ // If we are not rescaling per-channel then we need to collapse 1xN to N
+ // and push broadcastMap.
+ auto operand = isShift ? op.getShift() : op.getMultiplier();
+ auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
+ if (tensorType && tensorType.hasStaticShape() &&
+ tensorType.getShape()[0] == 1) {
+ // broadcastMap = affine_map<(d0, d1) -> ()>
+ // It would affect as broadcast for scalar values in linalg::GenericOp.
+ AffineMap broadcastMap =
+ AffineMap::get(rank, 0, {}, rewriter.getContext());
+ genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
+ indexingMaps.push_back(broadcastMap);
+ } else {
+ genericInputs.push_back(operand);
+ indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+ /*symbolCount=*/0, exprs,
+ rewriter.getContext()));
+ }
+ }
+ arg = indexingMaps.size() - 1;
+}
+
+// Return the extended Zp to be used in subsequent arithmetic operations.
+static Value getExtendZp(OpBuilder &builder, Type valueTy,
+ FailureOr<int64_t> maybeZp, Location loc,
+ ValueRange blockArgs, int64_t zpArg,
+ bool isOutputZp = false) {
+ Value result;
+ const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
+ const uint32_t attrBitwidth =
+ isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32);
+ auto extendType = builder.getIntegerType(attrBitwidth);
+ // The Zp value can be either constant or non-constant, depending on
+ // whether dynamic extension is enabled.
+ // If 'maybeZp' fails, it indicates that Zp is non-constant and will
+ // be passed as an input to linalg::GenericOp.
+ if (failed(maybeZp)) {
+ result = blockArgs[zpArg];
+ auto zpTy = result.getType();
+ if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) {
+ // For ExtUIOp, the input must be signless.
+ // UnrealizedConversionCastOp will cast the input to signless type.
+ if (zpTy.isUnsignedInteger()) {
+ result =
+ UnrealizedConversionCastOp::create(
+ builder, loc,
+ builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result)
+ .getResult(0);
+ }
+ if (zpTy.isUnsignedInteger()) {
+ return builder.create<arith::ExtUIOp>(loc, extendType, result);
+ } else {
+ return builder.create<arith::ExtSIOp>(loc, extendType, result);
+ }
+ }
+ } else {
+ return builder.create<arith::ConstantOp>(
+ loc, IntegerAttr::get(extendType, *maybeZp));
+ }
+ return result;
+}
+
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -1423,40 +1554,46 @@ public:
}
}
- // The shift and multiplier values.
DenseElementsAttr shiftElems;
- if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
- return rewriter.notifyMatchFailure(
- op, "tosa.rescale requires constant shift input values");
+ bool isShiftConstant = false;
+ if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
+ isShiftConstant = true;
DenseElementsAttr multiplierElems;
- if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
- return rewriter.notifyMatchFailure(
- op, "tosa.rescale requires constant multiplier input values");
-
- llvm::SmallVector<int8_t> shiftValues =
- llvm::to_vector(shiftElems.getValues<int8_t>());
- // explicit cast is required here
- llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
- llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
- [](IntegerAttr attr) -> int32_t {
- return static_cast<int32_t>(attr.getInt());
- }));
-
- // If we shift by more than the bitwidth, this just sets to 0.
- for (int i = 0, s = multiplierValues.size(); i < s; i++) {
- if (shiftValues[i] > 63) {
- shiftValues[i] = 0;
- multiplierValues[i] = 0;
+ bool isMultiplierConstant = false;
+ if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
+ isMultiplierConstant = true;
+
+ llvm::SmallVector<int32_t> shiftValues;
+ llvm::SmallVector<int32_t> multiplierValues;
+ bool doubleRound;
+
+ if (isMultiplierConstant && isShiftConstant) {
+ // explicit cast is required here
+ shiftValues = llvm::to_vector(llvm::map_range(
+ shiftElems.getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t {
+ return static_cast<int32_t>(attr.getInt());
+ }));
+ multiplierValues = llvm::to_vector(
+ llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
+ [](IntegerAttr attr) -> int32_t {
+ return static_cast<int32_t>(attr.getInt());
+ }));
+
+ // If we shift by more than the bitwidth, this just sets to 0.
+ for (int i = 0, s = multiplierValues.size(); i < s; i++) {
+ if (shiftValues[i] > 63) {
+ shiftValues[i] = 0;
+ multiplierValues[i] = 0;
+ }
}
- }
+ // Double round only occurs if shift is greater than 31, check that this
+ // is ever true.
+ doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
+ llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+ } else
+ doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
- // Double round only occurs if shift is greater than 31, check that this
- // is ever true.
-
- bool doubleRound =
- op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
- llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
RoundingMode roundingMode =
doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
@@ -1468,45 +1605,43 @@ public:
// values in a buffer.
Value multiplierConstant;
int64_t multiplierArg = 0;
- if (multiplierValues.size() == 1) {
- multiplierConstant = arith::ConstantOp::create(
- rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
- } else {
- SmallVector<AffineExpr, 2> multiplierExprs{
- rewriter.getAffineDimExpr(rank - 1)};
- auto multiplierType =
- RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
- rewriter.getI32Type());
- genericInputs.push_back(arith::ConstantOp::create(
- rewriter, loc,
- DenseIntElementsAttr::get(multiplierType, multiplierValues)));
-
- indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
- /*symbolCount=*/0, multiplierExprs,
- rewriter.getContext()));
-
- multiplierArg = indexingMaps.size() - 1;
- }
+ setupLinalgGenericOpInputAndIndexingMap(
+ rewriter, multiplierValues, genericInputs, indexingMaps,
+ isMultiplierConstant, op, multiplierConstant, multiplierArg);
// If we are rescaling per-channel then we need to store the shift
// values in a buffer.
Value shiftConstant;
int64_t shiftArg = 0;
- if (shiftValues.size() == 1) {
- shiftConstant = arith::ConstantOp::create(
- rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
- } else {
- SmallVector<AffineExpr, 2> shiftExprs = {
- rewriter.getAffineDimExpr(rank - 1)};
- auto shiftType =
- RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
- rewriter.getIntegerType(8));
- genericInputs.push_back(arith::ConstantOp::create(
- rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
- indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
- /*symbolCount=*/0, shiftExprs,
- rewriter.getContext()));
- shiftArg = indexingMaps.size() - 1;
+ setupLinalgGenericOpInputAndIndexingMap(
+ rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
+ shiftConstant, shiftArg, true);
+
+ // broadcastMap = affine_map<(d0, d1) -> ()>
+ // It would affect as broadcast for scalar values in linalg::GenericOp.
+ AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
+ FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+ FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+ // The inputZp and outputZp may be either constant or non-constant,
+ // depending on whether dynamic extension is enabled.
+ // - If the zp's are non-constant, add them as an inputs to
+ // linalg::GenericOp by:
+ // 1. Pushing it into 'genericInputs'.
+ // 2. Appending a corresponding affine map to 'indexingMaps'.
+ // - If the zp's are constant, they would be generated as arith.constant.
+ int64_t iZpArg = 0;
+ if (failed(maybeIZp)) {
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
+ indexingMaps.push_back(broadcastMap);
+ iZpArg = indexingMaps.size() - 1;
+ }
+ int64_t oZpArg = 0;
+ if (failed(maybeOZp)) {
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
+ indexingMaps.push_back(broadcastMap);
+ oZpArg = indexingMaps.size() - 1;
}
// Indexing maps for output values.
@@ -1526,36 +1661,17 @@ public:
Type valueTy = value.getType();
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
- if (failed(maybeIZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "input zero point cannot be statically determined");
- return;
- }
-
- const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
- // Extend zeropoint for sub-32bits widths.
- const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
- auto inputZp = arith::ConstantOp::create(
- nestedBuilder, loc,
- IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
- *maybeIZp));
+ auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
+ nestedLoc, blockArgs, iZpArg);
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
- if (failed(maybeOZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "output zero point cannot be statically determined");
- return;
- };
+ auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
+ nestedLoc, blockArgs, oZpArg, true);
IntegerType outIntType =
cast<IntegerType>(blockArgs.back().getType());
unsigned outBitWidth = outIntType.getWidth();
- const int32_t outAttrBitwidth = 32;
assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
- auto outputZp = arith::ConstantOp::create(
- nestedBuilder, loc,
- IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
- *maybeOZp));
Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index 802691c..9bf9ca3 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
#include <numeric>
@@ -70,8 +71,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
// Calculate the product of all elements in 'newShape' except for the -1
// placeholder, which we discard by negating the result.
- int64_t totalSizeNoPlaceholder = -std::accumulate(
- newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());
+ int64_t totalSizeNoPlaceholder = -llvm::product_of(newShape);
// If there is a 0 component in 'newShape', resolve the placeholder as
// 0.
diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
index 79c2f23..245a3ef 100644
--- a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -20,6 +20,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
#include <numeric>
@@ -265,8 +266,7 @@ loadStoreFromTransfer(PatternRewriter &rewriter,
if (isPacked)
src = collapseLastDim(rewriter, src);
int64_t rows = vecShape[0];
- int64_t cols = std::accumulate(vecShape.begin() + 1, vecShape.end(), 1,
- std::multiplies<int64_t>());
+ int64_t cols = llvm::product_of(vecShape.drop_front());
auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
@@ -336,8 +336,7 @@ static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
ArrayRef<int64_t> shape = vecTy.getShape();
int64_t rows = shape[0];
- int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), 1,
- std::multiplies<int64_t>());
+ int64_t cols = llvm::product_of(shape.drop_front());
auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index c45c45e..c9eba69 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -26,6 +26,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOSCF
@@ -760,8 +761,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
if (vectorType.getRank() != 1) {
// Flatten n-D vectors to 1D. This is done to allow indexing with a
// non-constant value.
- auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
- std::multiplies<int64_t>());
+ int64_t flatLength = llvm::product_of(shape);
auto flatVectorType =
VectorType::get({flatLength}, vectorType.getElementType());
value = vector::ShapeCastOp::create(rewriter, loc, flatVectorType, value);
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 9ead1d8..71687b1 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -23,6 +23,7 @@
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -774,9 +775,7 @@ struct ConvertXeGPUToXeVMPass
if (rank < 1 || type.getNumElements() == 1)
return elemType;
// Otherwise, convert the vector to a flat vector type.
- int64_t sum =
- std::accumulate(type.getShape().begin(), type.getShape().end(),
- int64_t{1}, std::multiplies<int64_t>());
+ int64_t sum = llvm::product_of(type.getShape());
return VectorType::get(sum, elemType);
});
typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {