Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--  mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp                               62
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp                          21
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp           7
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp                  13
-rw-r--r--  mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp    4
-rw-r--r--  mlir/lib/Dialect/SCF/IR/SCF.cpp                                             96
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp                                      95
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp                          29
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaOps.cpp                                       377
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp                  48
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp                         22
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp                         25
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp                            6
-rw-r--r--  mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp                                 25
14 files changed, 691 insertions, 139 deletions
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 61166db..4c4965e 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -360,45 +360,53 @@ LogicalResult ScaledExtPacked816Op::verify() {
//===----------------------------------------------------------------------===//
// WMMAOp
//===----------------------------------------------------------------------===//
-LogicalResult WMMAOp::verify() {
- Type sourceAType = getSourceA().getType();
- Type sourceBType = getSourceB().getType();
- Type destType = getDestC().getType();
- VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
- VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
- VectorType destVectorType = dyn_cast<VectorType>(destType);
+ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
+ IntegerAttr &m, IntegerAttr &n,
+ IntegerAttr &k) {
+ SmallVector<int64_t, 3> dimensions;
+ if (parser.parseDimensionList(dimensions, false, false))
+ return failure();
+ if (dimensions.size() != 3)
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected 3 dimensions in MNK dimension list";
- Type sourceAElemType = sourceVectorAType.getElementType();
- Type sourceBElemType = sourceVectorBType.getElementType();
- Type destElemType = destVectorType.getElementType();
+ m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
+ n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
+ k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
+ return success();
+}
- if (sourceVectorAType.getNumElements() !=
- sourceVectorBType.getNumElements()) {
+LogicalResult WMMAOp::verify() {
+ auto sourceAType = cast<VectorType>(getSourceA().getType());
+ auto sourceBType = cast<VectorType>(getSourceB().getType());
+ auto destType = cast<VectorType>(getDestC().getType());
+
+ Type sourceAElemType = sourceAType.getElementType();
+ Type sourceBElemType = sourceBType.getElementType();
+ if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
return emitOpError("source vectors have different lengths: ")
- << sourceVectorAType << " vs. " << sourceVectorBType;
+ << sourceAType << " vs. " << sourceBType;
}
- bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
- bool isSrcFloat =
- isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
- sourceAElemType);
-
- if (isDestFloat && !isSrcFloat) {
- return emitOpError("Expected float sources with float destination");
- }
+ bool isDestFloat = destType.getElementType().isFloat();
+ bool isSrcFloat = sourceAElemType.isFloat();
- if (!isDestFloat && isSrcFloat) {
- return emitOpError("Expected int sources with int destination");
- }
+ if (isDestFloat && !isSrcFloat)
+ return emitOpError("expected float sources with float destination");
+ if (!isDestFloat && isSrcFloat)
+ return emitOpError("expected int sources with int destination");
- if (sourceAElemType != sourceBElemType &&
- !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
- isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
+ if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
return emitOpError(
"source element types much match (except for fp8) but have ")
<< sourceAType << " and " << sourceBType;
}
+
+ if (!sourceAElemType.isInteger(4) && getK() != 16) {
+ return emitOpError("K dimension must be 16 for source element type ")
+ << sourceAElemType;
+ }
return success();
}
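
The new mlir::amdgpu::parseMNKDimensionList hook reads a bare MxNxK triple. A minimal sketch of how it might be invoked, assuming it backs a custom<MNKDimensionList>($m, $n, $k) assembly directive (that directive name is an assumption; only the signature above comes from the patch):

    // Hypothetical call from a custom parser: "16x16x32" yields three i32 attrs.
    IntegerAttr m, n, k;
    if (mlir::amdgpu::parseMNKDimensionList(parser, m, n, k))
      return failure();
    // On success: m == 16, n == 16, k == 32.
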
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
index 316721b..60ae78b 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
@@ -45,18 +45,15 @@ struct LoopUnroll : public affine::impl::AffineLoopUnrollBase<LoopUnroll> {
const std::function<unsigned(AffineForOp)> getUnrollFactor;
LoopUnroll() : getUnrollFactor(nullptr) {}
- LoopUnroll(const LoopUnroll &other)
-
- = default;
+ LoopUnroll(const LoopUnroll &other) = default;
explicit LoopUnroll(
std::optional<unsigned> unrollFactor = std::nullopt,
- bool unrollUpToFactor = false, bool unrollFull = false,
+ bool unrollUpToFactor = false,
const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr)
: getUnrollFactor(getUnrollFactor) {
if (unrollFactor)
this->unrollFactor = *unrollFactor;
this->unrollUpToFactor = unrollUpToFactor;
- this->unrollFull = unrollFull;
}
void runOnOperation() override;
@@ -85,11 +82,17 @@ static void gatherInnermostLoops(FunctionOpInterface f,
}
void LoopUnroll::runOnOperation() {
+ if (!(unrollFactor.getValue() > 0 || unrollFactor.getValue() == -1)) {
+ emitError(UnknownLoc::get(&getContext()),
+ "Invalid option: 'unroll-factor' should be greater than 0 or "
+ "equal to -1");
+ return signalPassFailure();
+ }
FunctionOpInterface func = getOperation();
if (func.isExternal())
return;
- if (unrollFull && unrollFullThreshold.hasValue()) {
+ if (unrollFactor.getValue() == -1 && unrollFullThreshold.hasValue()) {
// Store short loops as we walk.
SmallVector<AffineForOp, 4> loops;
@@ -130,7 +133,7 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
return loopUnrollByFactor(forOp, getUnrollFactor(forOp),
/*annotateFn=*/nullptr, cleanUpUnroll);
// Unroll completely if full loop unroll was specified.
- if (unrollFull)
+ if (unrollFactor.getValue() == -1)
return loopUnrollFull(forOp);
// Otherwise, unroll by the given unroll factor.
if (unrollUpToFactor)
@@ -141,9 +144,9 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
std::unique_ptr<InterfacePass<FunctionOpInterface>>
mlir::affine::createLoopUnrollPass(
- int unrollFactor, bool unrollUpToFactor, bool unrollFull,
+ int unrollFactor, bool unrollUpToFactor,
const std::function<unsigned(AffineForOp)> &getUnrollFactor) {
return std::make_unique<LoopUnroll>(
unrollFactor == -1 ? std::nullopt : std::optional<unsigned>(unrollFactor),
- unrollUpToFactor, unrollFull, getUnrollFactor);
+ unrollUpToFactor, getUnrollFactor);
}
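
For callers of the factory, full unrolling is now requested by passing -1 as the factor instead of the removed unrollFull flag. A minimal usage sketch; the signature is taken from the patch, while the pass-manager nesting shown is an assumption:

    // Unroll completely (-1 maps to loopUnrollFull in runOnAffineForOp above).
    pm.addNestedPass<func::FuncOp>(mlir::affine::createLoopUnrollPass(
        /*unrollFactor=*/-1, /*unrollUpToFactor=*/false));
    // Unroll by a fixed factor of 4 instead.
    pm.addNestedPass<func::FuncOp>(mlir::affine::createLoopUnrollPass(
        /*unrollFactor=*/4, /*unrollUpToFactor=*/false));
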
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index a6159ee..f0ddb50 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -14,13 +14,6 @@
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
-namespace mlir {
-namespace bufferization {
-#define GEN_PASS_DEF_TENSORCOPYINSERTION
-#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
-} // namespace bufferization
-} // namespace mlir
-
using namespace mlir;
using namespace mlir::bufferization;
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index 11400de..a15bf89 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -59,6 +59,17 @@ struct DimOpInterface
}
};
+struct ExpandShapeOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<ExpandShapeOpInterface,
+ memref::ExpandShapeOp> {
+ void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+ ValueBoundsConstraintSet &cstr) const {
+ auto expandOp = cast<memref::ExpandShapeOp>(op);
+ assert(value == expandOp.getResult() && "invalid value");
+ cstr.bound(value)[dim] == expandOp.getOutputShape()[dim];
+ }
+};
+
struct GetGlobalOpInterface
: public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
GetGlobalOp> {
@@ -123,6 +134,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
+ memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>(
+ *ctx);
memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp
index a9125ec..c117d9b 100644
--- a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp
@@ -189,7 +189,9 @@ class PrepareForOMPOffloadPrivatizationPass
DominanceInfo dom;
llvm::sort(chainOfOps, [&](Operation *l, Operation *r) {
- return dom.dominates(l, r);
+ if (l == r)
+ return false;
+ return dom.properlyDominates(l, r);
});
rewriter.setInsertionPoint(chainOfOps.front());
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9bd13f3..744a595 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -27,6 +27,7 @@
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/DebugLog.h"
@@ -291,9 +292,102 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
}
};
+// Pattern to eliminate ExecuteRegionOp results which forward external
+// values from the region. In case there are multiple yield operations,
+// all of them must have the same operands in order for the pattern to be
+// applicable.
+struct ExecuteRegionForwardingEliminator
+ : public OpRewritePattern<ExecuteRegionOp> {
+ using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExecuteRegionOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getNumResults() == 0)
+ return failure();
+
+ SmallVector<Operation *> yieldOps;
+ for (Block &block : op.getRegion()) {
+ if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator()))
+ yieldOps.push_back(yield.getOperation());
+ }
+
+ if (yieldOps.empty())
+ return failure();
+
+ // Check if all yield operations have the same operands.
+ auto yieldOpsOperands = yieldOps[0]->getOperands();
+ for (auto *yieldOp : yieldOps) {
+ if (yieldOp->getOperands() != yieldOpsOperands)
+ return failure();
+ }
+
+ SmallVector<Value> externalValues;
+ SmallVector<Value> internalValues;
+ SmallVector<Value> opResultsToReplaceWithExternalValues;
+ SmallVector<Value> opResultsToKeep;
+ for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) {
+ if (isValueFromInsideRegion(yieldedValue, op)) {
+ internalValues.push_back(yieldedValue);
+ opResultsToKeep.push_back(op.getResult(index));
+ } else {
+ externalValues.push_back(yieldedValue);
+ opResultsToReplaceWithExternalValues.push_back(op.getResult(index));
+ }
+ }
+ // No yielded external values - nothing to do.
+ if (externalValues.empty())
+ return failure();
+
+ // There are yielded external values - create a new execute_region returning
+ // just the internal values.
+ SmallVector<Type> resultTypes;
+ for (Value value : internalValues)
+ resultTypes.push_back(value.getType());
+ auto newOp =
+ ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes));
+ newOp->setAttrs(op->getAttrs());
+
+ // Move old op's region to the new operation.
+ rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
+ newOp.getRegion().end());
+
+ // Replace all yield operations with a new yield operation with updated
+ // results. scf.execute_region must have at least one yield operation.
+ for (auto *yieldOp : yieldOps) {
+ rewriter.setInsertionPoint(yieldOp);
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp,
+ ValueRange(internalValues));
+ }
+
+ // Replace the old operation with the external values directly.
+ rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues,
+ externalValues);
+ // Replace the old operation's remaining results with the new operation's
+ // results.
+ rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults());
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+private:
+ bool isValueFromInsideRegion(Value value,
+ ExecuteRegionOp executeRegionOp) const {
+ // Check if the value is defined within the execute_region
+ if (Operation *defOp = value.getDefiningOp())
+ return &executeRegionOp.getRegion() == defOp->getParentRegion();
+
+ // If it's a block argument, check if it's from within the region
+ if (BlockArgument blockArg = dyn_cast<BlockArgument>(value))
+ return &executeRegionOp.getRegion() == blockArg.getParentRegion();
+
+ return false; // Value is from outside the region
+ }
+};
+
void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
+ results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner,
+ ExecuteRegionForwardingEliminator>(context);
}
void ExecuteRegionOp::getSuccessorRegions(
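
Roughly, ExecuteRegionForwardingEliminator rewrites IR of the following shape (an illustrative MLIR sketch, not taken from the patch; value names are invented): results that merely forward values defined outside the region are replaced by those values, and a narrower execute_region keeps only the internally produced results.

    // Before: %r#1 only forwards %outside, which already dominates the op.
    %r:2 = scf.execute_region -> (i32, i32) {
      %inside = arith.addi %a, %b : i32
      scf.yield %inside, %outside : i32, i32
    }

    // After: uses of %r#1 are rewired to %outside; the region yields one value.
    %r2 = scf.execute_region -> i32 {
      %inside = arith.addi %a, %b : i32
      scf.yield %inside : i32
    }
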
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index 1cba1bb..eb47e85 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -12,6 +12,97 @@
namespace mlir {
namespace tosa {
+llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
+ return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
+}
+
+TosaSpecificationVersion getMinVersion(const Profile &profile) {
+ switch (profile) {
+ case Profile::pro_int:
+ case Profile::pro_fp:
+ return TosaSpecificationVersion(1, 0);
+ case Profile::none:
+ return TosaSpecificationVersion(0, 0);
+ }
+ llvm_unreachable("Unknown TOSA profile");
+}
+
+TosaSpecificationVersion getMinVersion(const Extension &extension) {
+ switch (extension) {
+ case Extension::int16:
+ case Extension::int4:
+ case Extension::bf16:
+ case Extension::fp8e4m3:
+ case Extension::fp8e5m2:
+ case Extension::fft:
+ case Extension::variable:
+ case Extension::controlflow:
+ case Extension::doubleround:
+ case Extension::inexactround:
+ case Extension::dynamic:
+ return TosaSpecificationVersion(1, 0);
+ case Extension::mxfp:
+ case Extension::int64:
+ return TosaSpecificationVersion(1, 1);
+ case Extension::none:
+ return TosaSpecificationVersion(0, 0);
+ }
+ llvm_unreachable("Unknown TOSA extension");
+}
+
+TosaSpecificationVersion getMinVersion(const Level &level) {
+ switch (level) {
+ case Level::eightK:
+ case Level::none:
+ return TosaSpecificationVersion(1, 0);
+ }
+ llvm_unreachable("Unknown TOSA level");
+}
+
+FailureOr<TargetEnv>
+TargetEnv::createTargetEnvFromAttr(TargetEnvAttr targetAttr,
+ Location targetEnvAttrLoc) {
+ if (failed(verifyTargetInformation(targetAttr, targetEnvAttrLoc)))
+ return failure();
+
+ return TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
+ targetAttr.getProfiles(), targetAttr.getExtensions());
+}
+
+LogicalResult TargetEnv::verifyTargetInformation(TargetEnvAttr targetAttr,
+ Location targetAttrLoc) {
+ TosaSpecificationVersion targetVersion(targetAttr.getSpecificationVersion());
+
+ const auto isCompatibleWithTargetVersion =
+ [&](const auto &targetEnum, Location targetAttrLoc,
+ StringRef enumName) -> LogicalResult {
+ const TosaSpecificationVersion minRequiredVersion =
+ getMinVersion(targetEnum);
+ if (!targetVersion.isBackwardsCompatibleWith(minRequiredVersion))
+ return emitError(targetAttrLoc, enumName)
+ << " '" << stringifyEnum(targetEnum)
+ << "' is not compatible with the target version "
+ << stringifyVersion(targetVersion)
+ << ", minimum required version is "
+ << stringifyVersion(minRequiredVersion);
+ return success();
+ };
+
+ for (const auto &profile : targetAttr.getProfiles())
+ if (failed(
+ isCompatibleWithTargetVersion(profile, targetAttrLoc, "profile")))
+ return failure();
+ for (const auto &extension : targetAttr.getExtensions())
+ if (failed(isCompatibleWithTargetVersion(extension, targetAttrLoc,
+ "extension")))
+ return failure();
+ if (failed(isCompatibleWithTargetVersion(targetAttr.getLevel(), targetAttrLoc,
+ "level")))
+ return failure();
+
+ return success();
+}
+
TargetEnvAttr lookupTargetEnv(Operation *op) {
while (op) {
op = SymbolTable::getNearestSymbolTable(op);
@@ -39,9 +130,5 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
return getDefaultTargetEnv(op->getContext());
}
-llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
- return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
-}
-
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 99b7cda..a85ff10a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -76,28 +76,6 @@ template <typename OpTy>
struct PoolPadFoldAdaptor;
template <>
-struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
- using OpTy = tosa::AvgPool2dOp;
- static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
- const llvm::ArrayRef<int64_t> kernel = op.getKernel();
- if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
- newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
- return false;
- return true;
- }
- static bool checkPadConstCompliance(OpTy op, Value padConst) {
- return checkMatchingPadConstAndZp(padConst, op.getInputZp());
- }
- static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
- Value padInput, ArrayRef<int64_t> newPad) {
- rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
- op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
- op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
- op.getAccType());
- }
-};
-
-template <>
struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
using OpTy = tosa::MaxPool2dOp;
static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
@@ -245,13 +223,6 @@ struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
};
} // namespace
-void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
- PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
- context);
-}
-
void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 00f84bc..0aff67f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -321,6 +321,19 @@ ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
}
}
+ // special handling: block_size accepts a *bare* BlockSizeMode enum
+ if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
+ if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) {
+ auto sym = symbolizeBlockSize(kw);
+ if (!sym)
+ return parser.emitError(parser.getCurrentLocation())
+ << "invalid block_size value: " << kw;
+ auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());
+ outAttrs.push_back(NamedAttribute(name, attr));
+ return success();
+ }
+ }
+
// Default path: parse any normal attribute literal, including fully qualified
// enum keyword
Attribute attr;
@@ -357,7 +370,7 @@ ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
result.operands)))
return failure();
- result.addTypes(fnTy.getResult(0));
+ result.addTypes(fnTy.getResults());
result.addAttributes(attrs);
return success();
@@ -373,6 +386,8 @@ void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
} else if (auto nanPropagationModeAttr =
dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
parser << nanPropagationModeAttr.getValue();
+ } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
+ parser << blockSizeAttr.getValue();
} else {
parser.printAttribute(attr);
}
@@ -508,6 +523,33 @@ void ReduceMinOp::print(OpAsmPrinter &parser) {
printWithNanPropagationHandling(parser, *this);
}
+ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
+ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
+ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
//===----------------------------------------------------------------------===//
// Tosa utilities.
//===----------------------------------------------------------------------===//
@@ -933,32 +975,35 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
// verify that inType and outType have same element types
template <typename T>
-static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
- auto inputType = llvm::dyn_cast<TensorType>(inType);
- auto outputType = llvm::dyn_cast<TensorType>(outType);
- if (!inputType) {
- op.emitOpError("expect shaped tensor for input, got ") << inType;
+static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
+ StringRef aName = "input",
+ StringRef bName = "output") {
+ auto aTType = llvm::dyn_cast<TensorType>(aType);
+ auto bTType = llvm::dyn_cast<TensorType>(bType);
+ if (!aTType) {
+ op.emitOpError("expect shaped tensor for") << aName << ", got " << aType;
return failure();
}
- if (!outputType) {
- op.emitOpError("expect shaped tensor for output, got ") << outType;
+ if (!bTType) {
+ op.emitOpError("expect shaped tensor for") << bName << ", got" << bType;
return failure();
}
- auto inputElementType = inputType.getElementType();
- auto outputElementType = outputType.getElementType();
- auto inputQuantType =
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
- auto outputQuantType =
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
- if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
- (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
- inputElementType != outputElementType) {
+ auto aElementType = aTType.getElementType();
+ auto bElementType = bTType.getElementType();
+ auto aQuantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
+ auto bQuantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
+ if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
+ (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
+ aElementType != bElementType) {
// only check if both element types are int/index/float/UniformQuantized
// eg, not sure how to check quant::QuantizedType
// this happens in test_conv2d_q_grouped_convolution in
// tfl-to-tosa-pipeline.mlir
- op.emitOpError("expect input and output to have same element type, got ")
- << inputElementType << " and " << outputElementType;
+ op.emitOpError("expect ")
+ << aName << " and " << bName << " to have same element type, got "
+ << aElementType << " and " << bElementType;
return failure();
}
return success();
@@ -1846,6 +1891,161 @@ LogicalResult MatMulOp::verify() {
return success();
}
+LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ MatmulTBlockScaledOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic);
+
+ const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
+ if (aDataShape.hasRank()) {
+ outShape[0] = aDataShape.getDimSize(0);
+ outShape[1] = aDataShape.getDimSize(1);
+ }
+
+ const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
+ if (aScaleShape.hasRank()) {
+ outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
+ : outShape[0];
+ outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
+ : outShape[1];
+ }
+
+ // If B batch size is 1, it is broadcast across A's batch size
+ const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
+ if (bDataShape.hasRank()) {
+ const int64_t bDataBatchSize = bDataShape.getDimSize(0);
+ if (bDataBatchSize != 1)
+ outShape[0] =
+ ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
+ outShape[2] = bDataShape.getDimSize(1);
+ }
+
+ const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
+ if (bScaleShape.hasRank()) {
+ const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
+ if (bScaleBatchSize != 1)
+ outShape[0] =
+ ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
+ outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
+ : outShape[2];
+ }
+
+ inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+ return success();
+}
+
+LogicalResult MatmulTBlockScaledOp::verify() {
+ // Verify same input data types
+ const Type aDataType = getAData().getType();
+ const Type bDataType = getBData().getType();
+ if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data",
+ "B_data")))
+ return failure();
+
+ auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
+ const StringRef operandName,
+ const StringRef dimName) -> LogicalResult {
+ if (ShapedType::isDynamic(currDim)) {
+ currDim = newDim;
+ return success();
+ } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
+ return emitOpError("expected ")
+ << dimName << " of " << operandName << " to match size " << currDim
+ << ", got " << newDim;
+ }
+ return success();
+ };
+
+ // Verify input shape compatibility
+ int64_t N = ShapedType::kDynamic;
+ int64_t D = ShapedType::kDynamic;
+ int64_t H = ShapedType::kDynamic;
+ int64_t W = ShapedType::kDynamic;
+ int64_t C = ShapedType::kDynamic;
+ int64_t multiplesOfC = ShapedType::kDynamic;
+
+ const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType);
+ if (aDataShape.hasRank()) {
+ N = aDataShape.getDimSize(0);
+ H = aDataShape.getDimSize(1);
+ C = aDataShape.getDimSize(2);
+ }
+
+ const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
+ if (aScaleShape.hasRank()) {
+ if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale",
+ "batch")) ||
+ failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale",
+ "height")))
+ return failure();
+ multiplesOfC = aScaleShape.getDimSize(2);
+ }
+
+ const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
+ if (bDataShape.hasRank()) {
+ if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data",
+ "batch")) ||
+ failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data",
+ "channels")))
+ return failure();
+ W = bDataShape.getDimSize(1);
+ }
+
+ const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
+ if (bScaleShape.hasRank()) {
+ if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale",
+ "batch")) ||
+ failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale",
+ "width")) ||
+ failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2),
+ "b_scale", "C/block_size")))
+ return failure();
+ }
+
+ // Verify batch size is broadcast compatible
+ if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
+ return emitOpError("expect B matrix batch size to be broadcast compatible "
+ "with A, got D=")
+ << D << " vs N=" << N;
+
+ // Verify C is a multiple of block size
+ const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ if (ShapedType::isStatic(C) && C % blockSize != 0)
+ return emitOpError("expect C to be a multiple of block size, got C=")
+ << C << ", block_size=" << blockSize;
+
+ // Verify multiplesOfC is C / block size
+ if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
+ multiplesOfC != C / blockSize)
+ return emitOpError(
+ "expect scale operands dimension 2 to equal C/block_size (")
+ << C << "/" << blockSize << ")"
+ << ", got " << multiplesOfC;
+
+ // Verify output shape
+ N = ShapedType::isDynamic(N) ? D : N;
+ const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
+ const auto outputType = cast<ShapedType>(getResult().getType());
+ if (outputType.hasRank() &&
+ failed(
+ verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
+ InFlightDiagnostic opError = emitOpError("expected output shape ");
+ auto stringifyDim = [&](int64_t d) {
+ if (ShapedType::isDynamic(d))
+ opError << "?";
+ else
+ opError << d;
+ };
+ llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
+ opError << " to be compatible with expected output shape ";
+ llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
+ return opError;
+ }
+
+ return success();
+}
+
LogicalResult tosa::PadOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
PadOp::Adaptor adaptor,
@@ -3762,6 +3962,145 @@ LogicalResult RescaleOp::inferReturnTypeComponents(
return success();
}
+LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ CastFromBlockScaledOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ const ShapeAdaptor inputShape(adaptor.getInputData().getType());
+ inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+ return success();
+}
+
+LogicalResult CastFromBlockScaledOp::verify() {
+ const Type inputDataType = getInputData().getType();
+ const Type outputDataType = getResult().getType();
+ if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+ return emitOpError() << "require compatible shapes for input_data ("
+ << inputDataType << ") and "
+ << "output_data (" << outputDataType << ")";
+
+ const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
+
+ if (inputDataShape.hasRank()) {
+ const unsigned int blockSize =
+ BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ const int64_t inputDataLastDim =
+ inputDataShape.getDimSize(inputDataShape.getRank() - 1);
+ if (inputDataLastDim % blockSize != 0)
+ return emitOpError() << "expect last dimension of input_data ("
+ << inputDataLastDim
+ << ") to be divisible by block_size (" << blockSize
+ << ")";
+
+ const Type inputScaleType = getInputScale().getType();
+ const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
+
+ if (inputScaleShape.hasRank()) {
+ SmallVector<int64_t> inputDataDims, inputScaleDims;
+ inputDataShape.getDims(inputDataDims);
+ inputScaleShape.getDims(inputScaleDims);
+
+ if (inputDataDims.size() != inputScaleDims.size() ||
+ failed(verifyCompatibleShape(
+ ArrayRef<int64_t>(inputDataDims).drop_back(1),
+ ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
+ return emitOpError() << "require compatible shapes for input_data ("
+ << inputDataType << ") and "
+ << "input_scale (" << inputScaleType
+ << ") except for the last dimension";
+
+ const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
+ inputScaleDims.back()};
+ if (ShapedType::isStatic(inputDataLastDim) &&
+ failed(verifyCompatibleDims(dimsToCheck)))
+ return emitOpError()
+ << "expect last dimension of input_scale ("
+ << inputScaleDims.back()
+ << ") to be equal to last dimension of input_data / block_size ("
+ << inputDataDims.back() / blockSize << ")";
+ }
+ }
+
+ return success();
+}
+
+LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ CastToBlockScaledOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ const ShapeAdaptor inputShape(adaptor.getInputData().getType());
+ inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+ if (!inputShape.hasRank())
+ return success();
+
+ // Calculate output_scale shape if ranked input provided
+ SmallVector<int64_t> outputScaleShape;
+ inputShape.getDims(outputScaleShape);
+ const int64_t lastDimLoc = inputShape.getRank() - 1;
+ const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
+ if (ShapedType::isStatic(lastDimSize)) {
+ const unsigned int blockSize =
+ BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
+ outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
+ }
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
+ return success();
+}
+
+LogicalResult CastToBlockScaledOp::verify() {
+ const Type inputDataType = getInputData().getType();
+ const Type outputDataType = getResult(0).getType();
+ if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+ return emitOpError() << "require compatible shapes for input_data ("
+ << inputDataType << ") and "
+ << "output_data (" << outputDataType << ")";
+
+ const unsigned int blockSize =
+ BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
+ if (inputDataShape.hasRank()) {
+ const int64_t inputDataLastDim =
+ inputDataShape.getDimSize(inputDataShape.getRank() - 1);
+ if (ShapedType::isStatic(inputDataLastDim) &&
+ inputDataLastDim % blockSize != 0)
+ return emitOpError() << "expect last dimension of input_data ("
+ << inputDataLastDim
+ << ") to be divisible by block_size (" << blockSize
+ << ")";
+ }
+
+ const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
+ const Type outputScaleType = getResult(1).getType();
+ const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
+ if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
+ SmallVector<int64_t> outputDataDims, outputScaleDims;
+ outputDataShape.getDims(outputDataDims);
+ outputScaleShape.getDims(outputScaleDims);
+
+ if (outputDataDims.size() != outputScaleDims.size() ||
+ failed(verifyCompatibleShape(
+ ArrayRef<int64_t>(outputDataDims).drop_back(1),
+ ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
+ return emitOpError() << "require compatible shapes for output_data ("
+ << outputDataType << ") and "
+ << "output_scale (" << outputScaleType
+ << ") except for the last dimension";
+
+ const int64_t outputDataLastDim = outputDataDims.back();
+ const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
+ outputScaleDims.back()};
+ if (ShapedType::isStatic(outputDataLastDim) &&
+ failed(verifyCompatibleDims(dimsToCheck)))
+ return emitOpError()
+ << "expect last dimension of output_scale ("
+ << outputScaleDims.back()
+ << ") to be equal to last dimension of output_data / block_size ("
+ << outputDataDims.back() / blockSize << ")";
+ }
+
+ return success();
+}
+
LogicalResult IfOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
IfOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index f072e3e..ab363ee 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -19,12 +19,19 @@ TosaProfileCompliance::TosaProfileCompliance() {
const TypeInfo i16T = {mlir::IntegerType::getTypeID(), 16};
const TypeInfo i32T = {mlir::IntegerType::getTypeID(), 32};
const TypeInfo i48T = {mlir::IntegerType::getTypeID(), 48};
+ const TypeInfo i64T = {mlir::IntegerType::getTypeID(), 64};
const TypeInfo bf16T = {mlir::BFloat16Type::getTypeID(), 16};
const TypeInfo fp16T = {mlir::Float16Type::getTypeID(), 16};
const TypeInfo fp32T = {mlir::Float32Type::getTypeID(), 32};
const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8};
const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8};
+ // micro-scaling formats
+ const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6};
+ const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
+ const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
+ const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
+
// The profile-based compliance content below is auto-generated by a script
// in https://git.mlplatform.org/tosa/specification.git
#include "mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc"
@@ -44,10 +51,11 @@ TosaProfileCompliance::getProfileComplianceMap() {
// Base populating function
LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
- Value output) {
- for (auto operand : operands)
+ ValueRange results) {
+ for (const auto &operand : operands)
addValue(operand);
- addValue(output);
+ for (const auto &result : results)
+ addValue(result);
return success();
}
@@ -170,23 +178,6 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
}
template <>
-LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
- addValue(op.getInputReal());
- addValue(op.getInputImag());
- addValue(op.getOutputReal());
- addValue(op.getOutputImag());
- return success();
-}
-
-template <>
-LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
- addValue(op.getInputReal());
- addValue(op.getOutputReal());
- addValue(op.getOutputImag());
- return success();
-}
-
-template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
addValue(op.getOnTrue());
addValue(op.getOnFalse());
@@ -239,7 +230,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// This helper function populates the info for all operands.
#define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
if (isa<tosa::tosaOp##Op>(op)) { \
- return populateProfileInfo(op->getOperands(), op->getResult(0)); \
+ return populateProfileInfo(op->getOperands(), op->getResults()); \
}
// Skip irrelevant operands when they are independent and not tied to any
@@ -250,8 +241,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_CUSTOM(Conv3D)
POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
POPULATE_PROFILE_INFO_CUSTOM(Mul)
- POPULATE_PROFILE_INFO_CUSTOM(FFT2d)
- POPULATE_PROFILE_INFO_CUSTOM(RFFT2d)
POPULATE_PROFILE_INFO_CUSTOM(Concat)
POPULATE_PROFILE_INFO_CUSTOM(Pad)
POPULATE_PROFILE_INFO_CUSTOM(Reshape)
@@ -269,7 +258,12 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// For the most of tosa operators, all operands are profile/extension related
// and hence are all considered in this profile-based compilance check.
+ POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled)
+ POPULATE_PROFILE_INFO_COMMON(FFT2d)
+ POPULATE_PROFILE_INFO_COMMON(RFFT2d)
POPULATE_PROFILE_INFO_COMMON(Cast)
+ POPULATE_PROFILE_INFO_COMMON(CastFromBlockScaled)
+ POPULATE_PROFILE_INFO_COMMON(CastToBlockScaled)
POPULATE_PROFILE_INFO_COMMON(Const)
POPULATE_PROFILE_INFO_COMMON(ArgMax)
POPULATE_PROFILE_INFO_COMMON(Sub)
@@ -623,6 +617,14 @@ TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) {
return {"fp8e4m3"};
} else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) {
return {"fp8e5m2"};
+ } else if (typeInfo.typeID == mlir::Float6E2M3FNType::getTypeID()) {
+ return {"fp6e2m3"};
+ } else if (typeInfo.typeID == mlir::Float6E3M2FNType::getTypeID()) {
+ return {"fp6e3m2"};
+ } else if (typeInfo.typeID == mlir::Float4E2M1FNType::getTypeID()) {
+ return {"fp4e2m1"};
+ } else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) {
+ return {"fp8e8m0"};
}
llvm_unreachable("unknown type");
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 82f2f7e..4d0b61a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -635,6 +635,8 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_RANKS_AND_SIZES(Transpose);
// Type Conversion
CHECK_RANKS_AND_SIZES(Cast);
+ CHECK_RANKS_AND_SIZES(CastFromBlockScaled);
+ CHECK_RANKS_AND_SIZES(CastToBlockScaled);
CHECK_RANKS_AND_SIZES(Rescale);
// Control Flow Operators
CHECK_RANKS_AND_SIZES(If);
@@ -657,6 +659,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_SIZES(TransposeConv2D);
CHECK_SIZES(FFT2d);
CHECK_SIZES(MatMul);
+ CHECK_SIZES(MatmulTBlockScaled);
CHECK_SIZES(MaxPool2d);
CHECK_SIZES(RFFT2d);
// Scatter/Gather Operators
@@ -1192,9 +1195,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
if (isa<FloatType>(type)) {
return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
- Float8E5M2Type>(type);
- }
- if (auto intTy = dyn_cast<IntegerType>(type)) {
+ Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType,
+ Float6E3M2FNType, Float8E8M0FNUType>(type);
+ } else if (auto intTy = dyn_cast<IntegerType>(type)) {
if (intTy.isSignless()) {
switch (intTy.getWidth()) {
case 1:
@@ -1203,6 +1206,7 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
case 16:
case 32:
case 48:
+ case 64:
return true;
}
} else if (allowUnsigned && intTy.isUnsigned()) {
@@ -1220,13 +1224,19 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
}
void TosaValidation::runOnOperation() {
+ ModuleOp modOp = getOperation();
+ const TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(modOp);
+ const auto maybeTargetEnv =
+ tosa::TargetEnv::createTargetEnvFromAttr(targetEnvAttr, modOp.getLoc());
+ if (failed(maybeTargetEnv))
+ return signalPassFailure();
+ targetEnv = *maybeTargetEnv;
+
TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
if (!tosaDialect)
return;
- targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation()));
-
- getOperation().walk([&](Operation *op) {
+ modOp.walk([&](Operation *op) {
if (op->getDialect() != tosaDialect)
return;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index f77784a..2c37140 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -145,8 +145,26 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
xegpu::DistributeLayoutAttr layout =
xegpu::getDistributeLayoutAttr(operandOrResult);
if (layout && layout.isForSubgroup()) {
- if (!layout.getEffectiveInstDataAsInt().empty())
- return layout.getEffectiveInstDataAsInt();
+ if (!layout.getEffectiveInstDataAsInt().empty()) {
+ SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
+ // Remove leading unit dimensions from inst_data
+ // For example, if the inst_data is [1, 1, 32]
+ // it will pass [32] as the unroll/blocking size.
+ // Skip it for xegpu nd ops since it will be 2D
+ // TODO: For vectors ops, experiment with the
+ // upstream vector remove leading unit dims patterns,
+ // populateCastAwayVectorLeadingOneDimPatterns.
+ Operation *definingOp = value.getDefiningOp();
+ bool skipLeadingUnitDimRemoval =
+ definingOp &&
+ (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::DpasOp,
+ xegpu::StoreNdOp, xegpu::PrefetchNdOp>(definingOp));
+ if (!skipLeadingUnitDimRemoval) {
+ auto it = llvm::find_if(instData, [](auto val) { return val != 1; });
+ instData.erase(instData.begin(), it);
+ }
+ return instData;
+ }
if (auto type = dyn_cast<ShapedType>(value.getType()))
return llvm::to_vector(type.getShape());
@@ -354,7 +372,6 @@ void XeGPUBlockingPass::runOnOperation() {
// To create a new attribute with a different chunk_size:
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
ctx, tdescTy.getMemorySpace(), blockedChunkSize);
-
encoding = newEncoding;
}
}
@@ -363,7 +380,7 @@ void XeGPUBlockingPass::runOnOperation() {
xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
tdescTy.getLayoutAttr().dropInstData());
} else {
- newTy = type.clone(tileShape, elemTy);
+ newTy = VectorType::get(tileShape, elemTy);
}
if (returnSingleType)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index aafa1b7..e6e71cc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -66,8 +66,6 @@ protected:
Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
Location loc, PatternRewriter &rewriter) const {
if (auto vecTy = dyn_cast<VectorType>(destTy)) {
- assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
- "Expecting blockSize size to match the rank of destTy.");
auto shape = vecTy.getShape();
return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
}
@@ -93,8 +91,6 @@ protected:
ArrayRef<int64_t> blockSize, Location loc,
PatternRewriter &rewriter) const {
if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
- assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
- "Expecting blockSize size to match the rank of src.");
return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
blockSize);
}
@@ -635,7 +631,7 @@ struct UnrollLoadGatherOpWithOffset
VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
Type elemTy = valueTy.getElementType();
- VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+ VectorType newValueTy = VectorType::get(*targetShape, elemTy);
SmallVector<Type> convertedMaskTypes;
SmallVector<Value> convertedMasks;
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 2c56a43..b4605cd 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -246,11 +246,28 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
if (!computeShapeRatio(srcShape, shape))
return {value};
+ int64_t srcShapeRank = srcShape.size();
+ int64_t targetShapeRank = shape.size();
+
+ SmallVector<int64_t> adjustedTargetShape(srcShape.size());
+ int64_t rankDiff = srcShapeRank - targetShapeRank;
+ std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
+ 1);
+ std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff);
+
SmallVector<Value> result;
- for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) {
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(srcShape, adjustedTargetShape)) {
SmallVector<int64_t> staticStrides(offsets.size(), 1);
- result.push_back(vector::ExtractStridedSliceOp::create(
- builder, loc, value, offsets, shape, staticStrides));
+ Value slice = vector::ExtractStridedSliceOp::create(
+ builder, loc, value, offsets, adjustedTargetShape, staticStrides);
+
+ // Reshape to remove leading unit dims if needed
+ if (srcShapeRank > targetShapeRank) {
+ auto targetTy = VectorType::get(shape, vecTy.getElementType());
+ slice = vector::ShapeCastOp::create(builder, loc, targetTy, slice);
+ }
+ result.push_back(slice);
}
return result;
@@ -274,7 +291,7 @@ Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
for (auto [src, offsets] :
llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
- SmallVector<int64_t> staticStrides(offsets.size(), 1);
+ SmallVector<int64_t> staticStrides(tileShape.size(), 1);
result = vector::InsertStridedSliceOp::create(builder, loc, src, result,
offsets, staticStrides);
}
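
The change to xegpu::extractVectorsWithShapeFromValue above lets callers request tiles of a lower rank than the source vector: the target shape is padded with leading ones for tiling, and each slice is shape_cast down to the requested rank. A hedged sketch of the intended behaviour, with the parameter types assumed from the surrounding code and made-up values:

    // With %value : vector<1x1x64xf16> and a requested shape of {32}, the helper
    // tiles using an adjusted shape of {1, 1, 32} and shape_casts each
    // vector<1x1x32xf16> slice to vector<32xf16>.
    SmallVector<Value> pieces =
        xegpu::extractVectorsWithShapeFromValue(builder, loc, value, {32});
    // Expected: pieces.size() == 2, each piece of type vector<32xf16>.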