Diffstat (limited to 'mlir/lib')
-rw-r--r--  mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp |   3
-rw-r--r--  mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp |  54
-rw-r--r--  mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp |  98
-rw-r--r--  mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp |  65
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp |  18
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp |  12
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp | 100
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp |  35
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h |  27
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp |   2
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp |  12
-rw-r--r--  mlir/lib/Dialect/Tensor/IR/TensorOps.cpp |  16
-rw-r--r--  mlir/lib/Target/Cpp/TranslateToCpp.cpp |  30
13 files changed, 292 insertions(+), 180 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 19cc914..337f8bb 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1532,7 +1532,8 @@ public:
auto punct = printOp.getPunctuation();
if (auto stringLiteral = printOp.getStringLiteral()) {
LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
- *stringLiteral, *getTypeConverter());
+ *stringLiteral, *getTypeConverter(),
+ /*addNewline=*/false);
} else if (punct != PrintPunctuation::NoPunctuation) {
emitCall(rewriter, printOp->getLoc(), [&] {
switch (punct) {
diff --git a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
index e645afe..fc0515b 100644
--- a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
@@ -195,43 +195,25 @@ DenseSet<Value> mlir::affine::getInvariantAccesses(Value iv,
return res;
}
-/// Given:
-/// 1. an induction variable `iv` of type AffineForOp;
-/// 2. a `memoryOp` of type const LoadOp& or const StoreOp&;
-/// determines whether `memoryOp` has a contiguous access along `iv`. Contiguous
-/// is defined as either invariant or varying only along a unique MemRef dim.
-/// Upon success, the unique MemRef dim is written in `memRefDim` (or -1 to
-/// convey the memRef access is invariant along `iv`).
-///
-/// Prerequisites:
-/// 1. `memRefDim` ~= nullptr;
-/// 2. `iv` of the proper type;
-/// 3. the MemRef accessed by `memoryOp` has no layout map or at most an
-/// identity layout map.
-///
-/// Currently only supports no layoutMap or identity layoutMap in the MemRef.
-/// Returns false if the MemRef has a non-identity layoutMap or more than 1
-/// layoutMap. This is conservative.
-///
-// TODO: check strides.
+// TODO: check access stride.
template <typename LoadOrStoreOp>
-static bool isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
- int *memRefDim) {
- static_assert(
- llvm::is_one_of<LoadOrStoreOp, AffineLoadOp, AffineStoreOp>::value,
- "Must be called on either LoadOp or StoreOp");
+bool mlir::affine::isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
+ int *memRefDim) {
+ static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
+ AffineWriteOpInterface>::value,
+ "Must be called on either an affine read or write op");
assert(memRefDim && "memRefDim == nullptr");
auto memRefType = memoryOp.getMemRefType();
if (!memRefType.getLayout().isIdentity())
- return memoryOp.emitError("NYI: non-trivial layoutMap"), false;
+ return memoryOp.emitError("NYI: non-trivial layout map"), false;
int uniqueVaryingIndexAlongIv = -1;
auto accessMap = memoryOp.getAffineMap();
SmallVector<Value, 4> mapOperands(memoryOp.getMapOperands());
unsigned numDims = accessMap.getNumDims();
for (unsigned i = 0, e = memRefType.getRank(); i < e; ++i) {
- // Gather map operands used result expr 'i' in 'exprOperands'.
+ // Gather map operands used in result expr 'i' in 'exprOperands'.
SmallVector<Value, 4> exprOperands;
auto resultExpr = accessMap.getResult(i);
resultExpr.walk([&](AffineExpr expr) {
@@ -241,7 +223,7 @@ static bool isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
exprOperands.push_back(mapOperands[numDims + symExpr.getPosition()]);
});
// Check access invariance of each operand in 'exprOperands'.
- for (auto exprOperand : exprOperands) {
+ for (Value exprOperand : exprOperands) {
if (!isAccessIndexInvariant(iv, exprOperand)) {
if (uniqueVaryingIndexAlongIv != -1) {
// 2+ varying indices -> do not vectorize along iv.
@@ -259,6 +241,13 @@ static bool isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
return true;
}
+template bool mlir::affine::isContiguousAccess(Value iv,
+ AffineReadOpInterface loadOp,
+ int *memRefDim);
+template bool mlir::affine::isContiguousAccess(Value iv,
+                                               AffineWriteOpInterface storeOp,
+ int *memRefDim);
+
template <typename LoadOrStoreOp>
static bool isVectorElement(LoadOrStoreOp memoryOp) {
auto memRefType = memoryOp.getMemRefType();
@@ -344,10 +333,13 @@ bool mlir::affine::isVectorizableLoopBody(
auto load = dyn_cast<AffineLoadOp>(op);
auto store = dyn_cast<AffineStoreOp>(op);
int thisOpMemRefDim = -1;
- bool isContiguous = load ? isContiguousAccess(loop.getInductionVar(), load,
- &thisOpMemRefDim)
- : isContiguousAccess(loop.getInductionVar(), store,
- &thisOpMemRefDim);
+ bool isContiguous =
+ load ? isContiguousAccess(loop.getInductionVar(),
+ cast<AffineReadOpInterface>(*load),
+ &thisOpMemRefDim)
+ : isContiguousAccess(loop.getInductionVar(),
+ cast<AffineWriteOpInterface>(*store),
+ &thisOpMemRefDim);
if (thisOpMemRefDim != -1) {
// If memory accesses vary across different dimensions then the loop is
// not vectorizable.
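With isContiguousAccess now exported from mlir::affine and instantiated for the read/write interfaces, out-of-file callers can query contiguity directly. A minimal caller sketch (includes and pass scaffolding are assumed, not part of this patch):

// Sketch: for each affine load in a loop, check whether its access is
// contiguous along the loop's induction variable. memRefDim receives the
// unique varying memref dimension, or -1 if the access is invariant.
void analyzeContiguity(mlir::affine::AffineForOp forOp) {
  forOp.walk([&](mlir::affine::AffineLoadOp load) {
    int memRefDim = -1;
    if (mlir::affine::isContiguousAccess(
            forOp.getInductionVar(),
            mlir::cast<mlir::affine::AffineReadOpInterface>(*load),
            &memRefDim)) {
      // Contiguous along the IV; memRefDim names the single varying dim.
    }
  });
}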
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 8deb8f0..7f246da 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -261,68 +261,62 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
}
- Type i1Ty = b.getI1Type();
Type i16Ty = b.getI16Type();
Type i32Ty = b.getI32Type();
Type f32Ty = b.getF32Type();
if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
- i1Ty = shapedTy.clone(i1Ty);
i16Ty = shapedTy.clone(i16Ty);
i32Ty = shapedTy.clone(i32Ty);
f32Ty = shapedTy.clone(f32Ty);
}
- Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
-
- Value c23 = createConst(op.getLoc(), i32Ty, 23, rewriter);
- Value c31 = createConst(op.getLoc(), i32Ty, 31, rewriter);
- Value c23Mask = createConst(op.getLoc(), i32Ty, (1 << 23) - 1, rewriter);
- Value expMask =
- createConst(op.getLoc(), i32Ty, ((1 << 8) - 1) << 23, rewriter);
- Value expMax =
- createConst(op.getLoc(), i32Ty, ((1 << 8) - 2) << 23, rewriter);
-
- // Grab the sign bit.
- Value sign = b.create<arith::ShRUIOp>(bitcast, c31);
-
- // Our mantissa rounding value depends on the sign bit and the last
- // truncated bit.
- Value cManRound = createConst(op.getLoc(), i32Ty, (1 << 15), rewriter);
- cManRound = b.create<arith::SubIOp>(cManRound, sign);
-
- // Grab out the mantissa and directly apply rounding.
- Value man = b.create<arith::AndIOp>(bitcast, c23Mask);
- Value manRound = b.create<arith::AddIOp>(man, cManRound);
-
- // Grab the overflow bit and shift right if we overflow.
- Value roundBit = b.create<arith::ShRUIOp>(manRound, c23);
- Value manNew = b.create<arith::ShRUIOp>(manRound, roundBit);
-
- // Grab the exponent and round using the mantissa's carry bit.
- Value exp = b.create<arith::AndIOp>(bitcast, expMask);
- Value expCarry = b.create<arith::AddIOp>(exp, manRound);
- expCarry = b.create<arith::AndIOp>(expCarry, expMask);
-
- // If the exponent is saturated, we keep the max value.
- Value expCmp =
- b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, exp, expMax);
- exp = b.create<arith::SelectOp>(expCmp, exp, expCarry);
-
- // If the exponent is max and we rolled over, keep the old mantissa.
- Value roundBitBool = b.create<arith::TruncIOp>(i1Ty, roundBit);
- Value keepOldMan = b.create<arith::AndIOp>(expCmp, roundBitBool);
- man = b.create<arith::SelectOp>(keepOldMan, man, manNew);
-
- // Assemble the now rounded f32 value (as an i32).
- Value rounded = b.create<arith::ShLIOp>(sign, c31);
- rounded = b.create<arith::OrIOp>(rounded, exp);
- rounded = b.create<arith::OrIOp>(rounded, man);
-
+ // Algorithm borrowed from this excellent code:
+ // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
+ // There is a magic idea there, to let the addition of the rounding_bias to
+ // the mantissa simply overflow into the exponent bits. It's a bit of an
+ // aggressive, obfuscating optimization, but it is well-tested code, and it
+ // results in more concise and efficient IR.
+    // The case of NaN is handled separately (see isNan and the final select).
+ // The case of infinities is NOT handled separately, which deserves an
+ // explanation. As the encoding of infinities has zero mantissa, the
+ // rounding-bias addition never carries into the exponent so that just gets
+ // truncated away, and as bfloat16 and float32 have the same number of
+ // exponent bits, that simple truncation is the desired outcome for
+ // infinities.
+ Value isNan =
+ b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
+ // Constant used to make the rounding bias.
+ Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
+ // Constant used to generate a quiet NaN.
+ Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
+ // Small constants used to address bits.
Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
- Value shr = b.create<arith::ShRUIOp>(rounded, c16);
- Value trunc = b.create<arith::TruncIOp>(i16Ty, shr);
- Value result = b.create<arith::BitcastOp>(resultTy, trunc);
-
+ Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
+ // Reinterpret the input f32 value as bits.
+ Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
+ // Read bit 16 as a value in {0,1}.
+ Value bit16 =
+ b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
+ // Determine the rounding bias to add as either 0x7fff or 0x8000 depending
+ // on bit 16, implementing the tie-breaking "to nearest even".
+ Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
+ // Add the rounding bias. Generally we want this to be added to the
+    // mantissa, but nothing prevents this from carrying into the exponent
+ // bits, which would feel like a bug, but this is the magic trick here:
+ // when that happens, the mantissa gets reset to zero and the exponent
+ // gets incremented by the carry... which is actually exactly what we
+ // want.
+ Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
+ // Now that the rounding-bias has been added, truncating the low bits
+ // yields the correctly rounded result.
+ Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
+ Value normalCaseResult_i16 =
+ b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
+ // Select either the above-computed result, or a quiet NaN constant
+ // if the input was NaN.
+ Value select =
+ b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+ Value result = b.create<arith::BitcastOp>(resultTy, select);
rewriter.replaceOp(op, result);
return success();
}
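The comments above compress a lot; the same rounding written as standalone scalar C++ may help (a sketch mirroring the generated IR, assuming IEEE-754 f32 on the host; not code from this patch):

#include <cstdint>
#include <cstring>

// Scalar model of the f32 -> bf16 truncation emitted above.
uint16_t truncF32ToBF16(float f) {
  if (f != f)
    return 0x7fc0;                         // quiet NaN, like c7FC0_i16
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));    // arith.bitcast
  uint32_t bit16 = (bits >> 16) & 1;       // lsb of the surviving mantissa
  uint32_t bias = 0x7fff + bit16;          // 0x7fff or 0x8000: ties to even
  bits += bias;                            // carry may ripple into exponent
  return static_cast<uint16_t>(bits >> 16);
}

For example, 1.00390625f (bits 0x3F808000) sits exactly halfway between the two nearest bf16 values; bit16 is 0, so the bias is 0x7fff and the sum 0x3F80FFFF truncates to 0x3F80, resolving the tie toward the even mantissa as the comments describe.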
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 69c3413..232635c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1445,6 +1445,38 @@ OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
return {};
}
+template <typename ToBufferOp>
+static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
+ OpaqueProperties prop,
+ RegionRange region,
+ SmallVectorImpl<mlir::Type> &ret) {
+ typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
+ SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
+ Type elemTp = nullptr;
+ bool withStride = false;
+ if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
+ elemTp = stt.getPosType();
+ } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
+ std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
+ elemTp = stt.getCrdType();
+ if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
+ withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
+ } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
+ elemTp = stt.getElementType();
+ }
+
+ assert(elemTp && "unhandled operation.");
+ SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
+ bufShape.push_back(ShapedType::kDynamic);
+
+  auto layout = withStride ? StridedLayoutAttr::get(stt.getContext(),
+                                                    ShapedType::kDynamic,
+                                                    {ShapedType::kDynamic})
+                           : StridedLayoutAttr();
+ ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
+ return success();
+}
+
LogicalResult ToPositionsOp::verify() {
auto stt = getSparseTensorType(getTensor());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
@@ -1454,6 +1486,14 @@ LogicalResult ToPositionsOp::verify() {
return success();
}
+LogicalResult
+ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
+ ValueRange ops, DictionaryAttr attr,
+ OpaqueProperties prop, RegionRange region,
+ SmallVectorImpl<mlir::Type> &ret) {
+ return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
+}
+
LogicalResult ToCoordinatesOp::verify() {
auto stt = getSparseTensorType(getTensor());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
@@ -1463,6 +1503,14 @@ LogicalResult ToCoordinatesOp::verify() {
return success();
}
+LogicalResult
+ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
+ ValueRange ops, DictionaryAttr attr,
+ OpaqueProperties prop, RegionRange region,
+ SmallVectorImpl<mlir::Type> &ret) {
+ return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
+}
+
LogicalResult ToCoordinatesBufferOp::verify() {
auto stt = getSparseTensorType(getTensor());
if (stt.getAoSCOOStart() >= stt.getLvlRank())
@@ -1470,6 +1518,14 @@ LogicalResult ToCoordinatesBufferOp::verify() {
return success();
}
+LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
+ MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
+ DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
+ SmallVectorImpl<mlir::Type> &ret) {
+ return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
+ ret);
+}
+
LogicalResult ToValuesOp::verify() {
auto stt = getSparseTensorType(getTensor());
auto mtp = getMemRefType(getResult());
@@ -1478,6 +1534,15 @@ LogicalResult ToValuesOp::verify() {
return success();
}
+LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
+ std::optional<Location> loc,
+ ValueRange ops, DictionaryAttr attr,
+ OpaqueProperties prop,
+ RegionRange region,
+ SmallVectorImpl<mlir::Type> &ret) {
+ return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
+}
+
LogicalResult ToSliceOffsetOp::verify() {
auto rank = getRankedTensorType(getSlice()).getRank();
if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index cdee8a4..cb75f6a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -496,11 +496,11 @@ static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
if (format == CuSparseFormat::kCOO) {
// Library uses SoA COO, direct IR uses AoS COO.
if (enableRT)
- return genToCoordinates(builder, loc, a, 0);
- return genToCoordinatesBuffer(builder, loc, a);
+ return builder.create<ToCoordinatesOp>(loc, a, 0);
+ return builder.create<ToCoordinatesBufferOp>(loc, a);
}
// Formats CSR/CSC and BSR use positions at 1.
- return genToPositions(builder, loc, a, 1);
+ return builder.create<ToPositionsOp>(loc, a, 1);
}
/// Generates the second coordinates of a sparse matrix.
@@ -510,7 +510,7 @@ static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
if (isCOO && !enableRT)
return Value(); // nothing needed
// Formats CSR/CSC and BSR use coordinates at 1.
- return genToCoordinates(builder, loc, a, 1);
+ return builder.create<ToCoordinatesOp>(loc, a, 1);
}
/// Generates the sparse matrix handle.
@@ -584,7 +584,7 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
- Value memV = genToValues(rewriter, loc, a);
+ Value memV = rewriter.create<ToValuesOp>(loc, a);
Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
Value valA = genAllocCopy(rewriter, loc, memV, tokens);
@@ -682,7 +682,7 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
Value memR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
Value memC = genSecondCrds(rewriter, loc, a, format, enableRT); // or empty
- Value memV = genToValues(rewriter, loc, a);
+ Value memV = rewriter.create<ToValuesOp>(loc, a);
Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
Value valA = genAllocCopy(rewriter, loc, memV, tokens);
@@ -785,10 +785,10 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
Value amemR = genFirstPosOrCrds(rewriter, loc, a, format, enableRT);
Value amemC = genSecondCrds(rewriter, loc, a, format, enableRT); // not empty
- Value amemV = genToValues(rewriter, loc, a);
+ Value amemV = rewriter.create<ToValuesOp>(loc, a);
Value bmemR = genFirstPosOrCrds(rewriter, loc, b, format, enableRT);
Value bmemC = genSecondCrds(rewriter, loc, b, format, enableRT); // not empty
- Value bmemV = genToValues(rewriter, loc, b);
+ Value bmemV = rewriter.create<ToValuesOp>(loc, b);
Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
@@ -1081,7 +1081,7 @@ static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
Value memR = genFirstPosOrCrds(rewriter, loc, c, format, enableRT);
Value memC = genSecondCrds(rewriter, loc, c, format, enableRT); // or empty
- Value memV = genToValues(rewriter, loc, c);
+ Value memV = rewriter.create<ToValuesOp>(loc, c);
Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
Value valC = genAllocCopy(rewriter, loc, memV, tokens);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index d5eec4a..4e33931 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1058,17 +1058,9 @@ public:
// Replace the requested coordinates access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
- Location loc = op.getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
- Value field = desc.getCrdMemRefOrView(rewriter, loc, op.getLevel());
-
- // Insert a cast to bridge the actual type to the user expected type. If the
- // actual type and the user expected type aren't compatible, the compiler or
- // the runtime will issue an error.
- Type resType = op.getResult().getType();
- if (resType != field.getType())
- field = rewriter.create<memref::CastOp>(loc, resType, field);
- rewriter.replaceOp(op, field);
+ rewriter.replaceOp(
+ op, desc.getCrdMemRefOrView(rewriter, op.getLoc(), op.getLevel()));
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 1bcc131..6ff2146 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -21,9 +21,11 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
@@ -598,6 +600,101 @@ public:
}
};
+/// Sparse rewriting rule for the print operator. This operation is mainly used
+/// for debugging and testing. As such, it lowers to the vector.print operation
+/// which only requires very lightweight runtime support.
+struct PrintRewriter : public OpRewritePattern<PrintOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(PrintOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto tensor = op.getTensor();
+ auto stt = getSparseTensorType(tensor);
+ // Header with NSE.
+ auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
+ rewriter.create<vector::PrintOp>(
+ loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
+ rewriter.create<vector::PrintOp>(loc, nse);
+ // Use the "codegen" foreach loop construct to iterate over
+ // all typical sparse tensor components for printing.
+ foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
+ &stt](Type, FieldIndex,
+ SparseTensorFieldKind kind,
+ Level l, LevelType) {
+ switch (kind) {
+ case SparseTensorFieldKind::StorageSpec: {
+ break;
+ }
+ case SparseTensorFieldKind::PosMemRef: {
+ auto lvl = constantIndex(rewriter, loc, l);
+ rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
+ rewriter.create<vector::PrintOp>(
+ loc, lvl, vector::PrintPunctuation::NoPunctuation);
+ rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
+ auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
+ printContents(rewriter, loc, pos);
+ break;
+ }
+ case SparseTensorFieldKind::CrdMemRef: {
+ auto lvl = constantIndex(rewriter, loc, l);
+ rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
+ rewriter.create<vector::PrintOp>(
+ loc, lvl, vector::PrintPunctuation::NoPunctuation);
+ rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
+ Value crd = nullptr;
+        // TODO: eliminate ToCoordinatesBufferOp!
+ if (stt.getAoSCOOStart() == l)
+ crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
+ else
+ crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
+ printContents(rewriter, loc, crd);
+ break;
+ }
+ case SparseTensorFieldKind::ValMemRef: {
+ rewriter.create<vector::PrintOp>(loc,
+ rewriter.getStringAttr("values : "));
+ auto val = rewriter.create<ToValuesOp>(loc, tensor);
+ printContents(rewriter, loc, val);
+ break;
+ }
+ }
+ return true;
+ });
+ rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+private:
+ // Helper to print contents of a single memref. Note that for the "push_back"
+ // vectors, this prints the full capacity, not just the size. This is done
+ // on purpose, so that clients see how much storage has been allocated in
+ // total. Contents of the extra capacity in the buffer may be uninitialized
+ // (unless the flag enable-buffer-initialization is set to true).
+ //
+ // Generates code to print:
+ // ( a0, a1, ... )
+ static void printContents(PatternRewriter &rewriter, Location loc,
+ Value vec) {
+ // Open bracket.
+ rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
+ // For loop over elements.
+ auto zero = constantIndex(rewriter, loc, 0);
+ auto size = rewriter.create<memref::DimOp>(loc, vec, zero);
+ auto step = constantIndex(rewriter, loc, 1);
+ auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ auto idx = forOp.getInductionVar();
+ auto val = rewriter.create<memref::LoadOp>(loc, vec, idx);
+ rewriter.create<vector::PrintOp>(loc, val, vector::PrintPunctuation::Comma);
+ rewriter.setInsertionPointAfter(forOp);
+ // Close bracket and end of line.
+ rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
+ rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
+ }
+};
+
/// Sparse rewriting rule for sparse-to-sparse reshape operator.
struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
public:
@@ -1284,7 +1381,8 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
- GenSemiRingReduction, GenSemiRingSelect>(patterns.getContext());
+ GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
+ patterns.getContext());
}
void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
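At run time, the loops built by PrintRewriter produce output shaped roughly like this for a CSR matrix (a sketch; values, levels, and any trailing capacity slots depend on the tensor and its storage):

---- Sparse Tensor ----
nse = 5
pos[1] : ( 0, 2, 4, 5,  )
crd[1] : ( 0, 3, 1, 4, 2,  )
values : ( 1, 2, 3, 4, 5,  )
----

Note that the addNewline=false change to createPrintStrCall in ConvertVectorToLLVM above is what lets the "nse = " label share a line with its value.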
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index b888dfa..fa57015 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -554,41 +554,6 @@ sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
.getResult();
}
-Value sparse_tensor::genToPositions(OpBuilder &builder, Location loc,
- Value tensor, Level lvl) {
- const auto srcTp = getSparseTensorType(tensor);
- const Type posTp = srcTp.getPosType();
- const Type memTp = get1DMemRefType(posTp, /*withLayout=*/false);
- return builder.create<ToPositionsOp>(loc, memTp, tensor,
- builder.getIndexAttr(lvl));
-}
-
-Value sparse_tensor::genToCoordinates(OpBuilder &builder, Location loc,
- Value tensor, Level lvl) {
- const auto srcTp = getSparseTensorType(tensor);
- const Type crdTp = srcTp.getCrdType();
- const Type memTp =
- get1DMemRefType(crdTp, /*withLayout=*/lvl >= srcTp.getAoSCOOStart());
- return builder.create<ToCoordinatesOp>(loc, memTp, tensor,
- builder.getIndexAttr(lvl));
-}
-
-Value sparse_tensor::genToCoordinatesBuffer(OpBuilder &builder, Location loc,
- Value tensor) {
- const auto srcTp = getSparseTensorType(tensor);
- const Type crdTp = srcTp.getCrdType();
- const Type memTp = get1DMemRefType(crdTp, /*withLayout=*/false);
- return builder.create<ToCoordinatesBufferOp>(loc, memTp, tensor);
-}
-
-Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
- Value tensor) {
- RankedTensorType srcTp = getRankedTensorType(tensor);
- Type valTp = get1DMemRefType(srcTp.getElementType(),
- /*withLayout=*/false);
- return builder.create<ToValuesOp>(loc, valTp, tensor);
-}
-
Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
Value tensor) {
return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index cc119bc..e8f6bd1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -228,17 +228,6 @@ void deallocDenseTensor(OpBuilder &builder, Location loc, Value buffer);
void sizesFromSrc(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
Location loc, Value src);
-/// Generates a 1D MemRefType with a dynamic size. When withLayout is set, the
-/// returned memref has a layout has unknown strides and offsets. Otherwise,
-/// a memref with a standard unit stride zero offset layout is returned.
-inline MemRefType get1DMemRefType(Type etp, bool withLayout) {
- auto layout = withLayout ? StridedLayoutAttr::StridedLayoutAttr::get(
- etp.getContext(), ShapedType::kDynamic,
- {ShapedType::kDynamic})
- : StridedLayoutAttr();
- return MemRefType::get(ShapedType::kDynamic, etp, layout);
-}
-
/// Scans to top of generated loop.
Operation *getTop(Operation *op);
@@ -281,22 +270,6 @@ void storeAll(OpBuilder &builder, Location loc, Value mem, ValueRange vs,
TypedValue<BaseMemRefType> genToMemref(OpBuilder &builder, Location loc,
Value tensor);
-/// Infers the result type and generates `ToPositionsOp`.
-Value genToPositions(OpBuilder &builder, Location loc, Value tensor, Level lvl);
-
-/// Infers the result type and generates `ToCoordinatesOp`. If the
-/// level is within a COO region, the result type is a memref with unknown
-/// stride and offset. Otherwise, the result type is a memref without
-/// any specified layout.
-Value genToCoordinates(OpBuilder &builder, Location loc, Value tensor,
- Level lvl);
-
-/// Infers the result type and generates `ToCoordinatesBufferOp`.
-Value genToCoordinatesBuffer(OpBuilder &builder, Location loc, Value tensor);
-
-/// Infers the result type and generates `ToValuesOp`.
-Value genToValues(OpBuilder &builder, Location loc, Value tensor);
-
/// Generates code to retrieve the values size for the sparse tensor.
Value genValMemSize(OpBuilder &builder, Location loc, Value tensor);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 0ead135..812c288 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -259,7 +259,7 @@ void LoopEmitter::initializeLoopEmit(
// Annotated sparse tensors.
// We also need the value buffer for all-dense annotated "sparse"
// tensors.
- valBuffer[t] = genToValues(builder, loc, tensor);
+ valBuffer[t] = builder.create<ToValuesOp>(loc, tensor);
}
// NOTE: we can also prepare for 0 lvl here in advance, this will hoist
// some loop preparation from tensor iteration, but will also (undesirably)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 011d814..8edacaa 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -1281,21 +1281,21 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
case LevelFormat::Batch:
llvm_unreachable("not implemented");
case LevelFormat::Compressed: {
- Value pos = genToPositions(b, l, t, lvl);
- Value crd = genToCoordinates(b, l, t, lvl);
+ Value pos = b.create<ToPositionsOp>(l, t, lvl);
+ Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
return std::make_unique<CompressedLevel>(tid, lvl, lt, sz, pos, crd);
}
case LevelFormat::LooseCompressed: {
- Value pos = genToPositions(b, l, t, lvl);
- Value crd = genToCoordinates(b, l, t, lvl);
+ Value pos = b.create<ToPositionsOp>(l, t, lvl);
+ Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
return std::make_unique<LooseCompressedLevel>(tid, lvl, lt, sz, pos, crd);
}
case LevelFormat::Singleton: {
- Value crd = genToCoordinates(b, l, t, lvl);
+ Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
}
case LevelFormat::NOutOfM: {
- Value crd = genToCoordinates(b, l, t, lvl);
+ Value crd = b.create<ToCoordinatesOp>(l, t, lvl);
return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
}
case LevelFormat::Undef:
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e6efec1..fe2f250 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4012,15 +4012,17 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
llvm::SmallSetVector<int64_t, 4> innerDims;
innerDims.insert(packOp.getInnerDimsPos().begin(),
packOp.getInnerDimsPos().end());
- auto outerDimsPerm = packOp.getOuterDimsPerm();
+ SmallVector<int64_t> inverseOuterDimsPerm;
+ if (!packOp.getOuterDimsPerm().empty())
+ inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
int srcRank = packOp.getSourceRank();
for (auto i : llvm::seq<int64_t>(0, srcRank)) {
if (innerDims.contains(i))
continue;
int64_t srcPos = i;
int64_t destPos = i;
- if (!outerDimsPerm.empty())
- destPos = outerDimsPerm[srcPos];
+ if (!inverseOuterDimsPerm.empty())
+ destPos = inverseOuterDimsPerm[srcPos];
if (ShapedType::isDynamic(srcShape[srcPos]) ==
ShapedType::isDynamic(destShape[destPos])) {
continue;
@@ -4240,15 +4242,17 @@ static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
op.getDestType().getShape().end());
llvm::SmallSetVector<int64_t, 4> innerDims;
innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
- auto outerDimsPerm = op.getOuterDimsPerm();
+ SmallVector<int64_t> inverseOuterDimsPerm;
+ if (!op.getOuterDimsPerm().empty())
+ inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
int destRank = op.getDestRank();
for (auto i : llvm::seq<int64_t>(0, destRank)) {
if (innerDims.contains(i))
continue;
int64_t srcPos = i;
int64_t destPos = i;
- if (!outerDimsPerm.empty())
- srcPos = outerDimsPerm[destPos];
+ if (!inverseOuterDimsPerm.empty())
+ srcPos = inverseOuterDimsPerm[destPos];
if (ShapedType::isDynamic(srcShape[srcPos]) ==
ShapedType::isDynamic(destShape[destPos])) {
continue;
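The reasoning behind swapping in the inverse permutation, spelled out (the permutation value is illustrative):

// tensor.pack's outer_dims_perm answers "which source dim feeds dest dim d":
// with outer_dims_perm = [2, 0, 1],
//   dest dim 0 <- src dim 2,  dest dim 1 <- src dim 0,  dest dim 2 <- src dim 1
// The loop above walks source dims and needs the opposite direction,
// src dim -> dest dim, which is exactly the inverse permutation:
//   invertPermutationVector({2, 0, 1}) == {1, 2, 0}   // src 0 -> dest 1, ...
// Indexing with the forward permutation, as the old code did, paired
// mismatched src/dest extents whenever the permutation was not its own
// inverse (the unpack case is the same argument with src and dest swapped).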
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 2ba3dec..16aa136 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -627,6 +627,33 @@ static LogicalResult printOperation(CppEmitter &emitter,
return success();
}
+static LogicalResult printOperation(CppEmitter &emitter,
+ emitc::LogicalAndOp logicalAndOp) {
+ Operation *operation = logicalAndOp.getOperation();
+ return printBinaryOperation(emitter, operation, "&&");
+}
+
+static LogicalResult printOperation(CppEmitter &emitter,
+ emitc::LogicalNotOp logicalNotOp) {
+ raw_ostream &os = emitter.ostream();
+
+ if (failed(emitter.emitAssignPrefix(*logicalNotOp.getOperation())))
+ return failure();
+
+ os << "!";
+
+ if (failed(emitter.emitOperand(logicalNotOp.getOperand())))
+ return failure();
+
+ return success();
+}
+
+static LogicalResult printOperation(CppEmitter &emitter,
+ emitc::LogicalOrOp logicalOrOp) {
+ Operation *operation = logicalOrOp.getOperation();
+ return printBinaryOperation(emitter, operation, "||");
+}
+
static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
raw_indented_ostream &os = emitter.ostream();
@@ -1284,7 +1311,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp,
emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp,
- emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
+ emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
+ emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
emitc::SubOp, emitc::VariableOp, emitc::VerbatimOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
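For the new EmitC handlers, the emitted C++ looks roughly like this (a sketch; the vN names are whatever CppEmitter assigns to the i1 operands and results):

bool v3 = v1 && v2;  // emitc.logical_and
bool v4 = v1 || v2;  // emitc.logical_or
bool v5 = !v1;       // emitc.logical_not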