Diffstat (limited to 'mlir/lib/Dialect')
31 files changed, 1171 insertions, 163 deletions
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp index e0c3abe..82a9fb0 100644 --- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp @@ -120,9 +120,7 @@ mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) { mapOperands.push_back(value1); mapOperands.push_back(value2); affine::fullyComposeAffineMapAndOperands(&map, &mapOperands); - ValueDimList valueDims; - for (Value v : mapOperands) - valueDims.push_back({v, std::nullopt}); return ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::EQ, map, valueDims); + presburger::BoundType::EQ, + ValueBoundsConstraintSet::Variable(map, mapOperands)); } diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp index 117ee8e..1a266b7 100644 --- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp @@ -16,16 +16,15 @@ using namespace mlir; using namespace mlir::affine; -static FailureOr<OpFoldResult> -reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, - Value value, std::optional<int64_t> dim, - ValueBoundsConstraintSet::StopConditionFn stopCondition, - bool closedUB) { +FailureOr<OpFoldResult> mlir::affine::reifyValueBound( + OpBuilder &b, Location loc, presburger::BoundType type, + const ValueBoundsConstraintSet::Variable &var, + ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) { // Compute bound. AffineMap boundMap; ValueDimList mapOperands; if (failed(ValueBoundsConstraintSet::computeBound( - boundMap, mapOperands, type, value, dim, stopCondition, closedUB))) + boundMap, mapOperands, type, var, stopCondition, closedUB))) return failure(); // Reify bound. @@ -93,7 +92,7 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound( // the owner of `value`. return v != value; }; - return reifyValueBound(b, loc, type, value, dim, + return reifyValueBound(b, loc, type, {value, dim}, stopCondition ? stopCondition : reifyToOperands, closedUB); } @@ -105,7 +104,7 @@ FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound( ValueBoundsConstraintSet &cstr) { return v != value; }; - return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt, + return reifyValueBound(b, loc, type, value, stopCondition ? 
stopCondition : reifyToOperands, closedUB); } diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp index f0d4380..7cfcc41 100644 --- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp @@ -107,9 +107,9 @@ struct SelectOpInterface // If trueValue <= falseValue: // * result <= falseValue // * result >= trueValue - if (cstr.compare(trueValue, dim, + if (cstr.compare(/*lhs=*/{trueValue, dim}, ValueBoundsConstraintSet::ComparisonOperator::LE, - falseValue, dim)) { + /*rhs=*/{falseValue, dim})) { if (dim) { cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim); cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim); @@ -121,9 +121,9 @@ struct SelectOpInterface // If falseValue <= trueValue: // * result <= trueValue // * result >= falseValue - if (cstr.compare(falseValue, dim, + if (cstr.compare(/*lhs=*/{falseValue, dim}, ValueBoundsConstraintSet::ComparisonOperator::LE, - trueValue, dim)) { + /*rhs=*/{trueValue, dim})) { if (dim) { cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim); cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim); diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp index 79fabd6..f87f3d6 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp @@ -449,7 +449,7 @@ struct IndexCastPattern final : NarrowingPattern<CastOp> { return failure(); FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::UB, in, /*dim=*/std::nullopt, + presburger::BoundType::UB, in, /*stopCondition=*/nullptr, /*closedUB=*/true); if (failed(ub)) return failure(); diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp index fad2212..5fb7953 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp @@ -61,16 +61,15 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map, return buildExpr(map.getResult(0)); } -static FailureOr<OpFoldResult> -reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type, - Value value, std::optional<int64_t> dim, - ValueBoundsConstraintSet::StopConditionFn stopCondition, - bool closedUB) { +FailureOr<OpFoldResult> mlir::arith::reifyValueBound( + OpBuilder &b, Location loc, presburger::BoundType type, + const ValueBoundsConstraintSet::Variable &var, + ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) { // Compute bound. AffineMap boundMap; ValueDimList mapOperands; if (failed(ValueBoundsConstraintSet::computeBound( - boundMap, mapOperands, type, value, dim, stopCondition, closedUB))) + boundMap, mapOperands, type, var, stopCondition, closedUB))) return failure(); // Materialize tensor.dim/memref.dim ops. @@ -128,7 +127,7 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound( // the owner of `value`. return v != value; }; - return reifyValueBound(b, loc, type, value, dim, + return reifyValueBound(b, loc, type, {value, dim}, stopCondition ? stopCondition : reifyToOperands, closedUB); } @@ -140,7 +139,7 @@ FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound( ValueBoundsConstraintSet &cstr) { return v != value; }; - return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt, + return reifyValueBound(b, loc, type, value, stopCondition ? 
stopCondition : reifyToOperands, closedUB); } diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp index 31500c6..b595c6d 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp @@ -165,6 +165,35 @@ int getNumberOfSMETilesForVectorType(VectorType type) { return (vectorRows * vectorCols) / (minNumElts * minNumElts); } +/// Legalize `arith.constant dense<value>` splat operations to fit within SME +/// tiles by decomposing them into tile-sized operations. +struct LegalizeArithConstantOpsByDecomposition + : public OneToNOpConversionPattern<arith::ConstantOp> { + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor, + OneToNPatternRewriter &rewriter) const override { + auto vectorType = dyn_cast<VectorType>(constantOp.getType()); + auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr()); + if (!vectorType || !denseAttr || !denseAttr.isSplat()) + return failure(); + + if (!isMultipleOfSMETileVectorType(vectorType)) + return rewriter.notifyMatchFailure(constantOp, + kMatchFailureNotSMETileTypeMultiple); + + auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); + auto tileCount = getNumberOfSMETilesForVectorType(vectorType); + auto tileSplat = rewriter.create<arith::ConstantOp>( + constantOp.getLoc(), denseAttr.resizeSplat(smeTileType)); + rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat), + adaptor.getResultMapping()); + + return success(); + } +}; + /// Legalize `vector.outerproduct` operations to fit within SME tiles by /// decomposing them into tile-sized operations. struct LegalizeVectorOuterProductOpsByDecomposition @@ -637,7 +666,8 @@ struct VectorLegalizationPass // Note: High benefit to ensure masked outer products are lowered first. patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>( converter, context, 1024); - patterns.add<LegalizeVectorOuterProductOpsByDecomposition, + patterns.add<LegalizeArithConstantOpsByDecomposition, + LegalizeVectorOuterProductOpsByDecomposition, LegalizeTransferReadOpsByDecomposition, LegalizeTransferWriteOpsByDecomposition>(converter, context); populateFuncTypeConversionPatterns(converter, patterns); diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt index b1ba5a3..a324ce7 100644 --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -28,6 +28,7 @@ add_subdirectory(OpenACCMPCommon) add_subdirectory(OpenMP) add_subdirectory(PDL) add_subdirectory(PDLInterp) +add_subdirectory(Polynomial) add_subdirectory(Quant) add_subdirectory(SCF) add_subdirectory(Shape) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp index 8c4b70d..518d2e1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp @@ -72,8 +72,10 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad, // Otherwise, try to compute a constant upper bound for the size value. 
FailureOr<int64_t> upperBound = ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::UB, opOperand->get(), - /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true); + presburger::BoundType::UB, + {opOperand->get(), + /*dim=*/i}, + /*stopCondition=*/nullptr, /*closedUB=*/true); if (failed(upperBound)) { LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding"); return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index ac896d6..71eb59d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -257,14 +257,12 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer( if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) { size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size); } else { - Value materializedSize = - getValueOrCreateConstantIndexOp(b, loc, rangeValue.size); FailureOr<int64_t> upperBound = ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::UB, materializedSize, /*dim=*/std::nullopt, + presburger::BoundType::UB, rangeValue.size, /*stopCondition=*/nullptr, /*closedUB=*/true); size = failed(upperBound) - ? materializedSize + ? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size) : b.create<arith::ConstantIndexOp>(loc, *upperBound); } LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 2578565..df61381 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1412,10 +1412,11 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp, /// Create a TransferReadOp from `source` with static shape `readShape`. If the /// vector type for the read is not the same as the type of `source`, then a -/// mask is created on the read. +/// mask is created on the read. If `doMasking` parameter is set to false we +/// update the `inBounds` attribute instead of masking. static Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef<int64_t> readShape, - Value padValue) { + Value padValue, bool doMasking = true) { assert(llvm::none_of(readShape, [](int64_t s) { return s == ShapedType::kDynamic; })); auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape(); @@ -1424,14 +1425,21 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc, auto vectorType = VectorType::get(readShape, padValue.getType()); int64_t readRank = readShape.size(); auto zero = builder.create<arith::ConstantIndexOp>(loc, 0); + SmallVector<bool> inBoundsVal(readRank, true); + if (!doMasking) { + // Update the inBounds attribute. 
+ for (unsigned i = 0; i < readRank; i++) + inBoundsVal[i] = sourceShape[i] == readShape[i]; + } auto transferReadOp = builder.create<vector::TransferReadOp>( loc, /*vectorType=*/vectorType, /*source=*/source, /*indices=*/SmallVector<Value>(readRank, zero), /*padding=*/padValue, - /*inBounds=*/SmallVector<bool>(readRank, true)); - if (llvm::equal(readShape, sourceShape)) { + /*inBounds=*/inBoundsVal); + + if (llvm::equal(readShape, sourceShape) || !doMasking) { return transferReadOp; } SmallVector<OpFoldResult> mixedSourceDims = @@ -1482,11 +1490,10 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc, return write; } -/// Vectorize tensor::PackOp with (1) static innerTiles and (2) constant -/// padding value into: +/// Vectorize tensor::PackOp with (1) static innerTiles (2) constant +/// padding value and (3) input vector sizes into: /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds /// As in the following example: -/// /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2] /// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32> /// @@ -1505,6 +1512,10 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc, /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0] /// {in_bounds = [true, true, true, true, true]} /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32> +/// +/// If the (3) input vector sizes are not provided, the vector sizes are +/// determined by the result tensor shape. Also, we update the inBounds +/// attribute instead of masking. static LogicalResult vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp, ArrayRef<int64_t> inputVectorSizes, @@ -1525,6 +1536,16 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp, (void)status; // prevent unused variable warning on non-assert builds. assert(succeeded(status) && "failed to reify result shapes"); + // If the input vector sizes are not provided, then the vector sizes are + // determined by the result tensor shape. In case the vector sizes aren't + // provided, we update the inBounds attribute instead of masking. + bool doMasking = true; + if (inputVectorSizes.empty()) { + ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape(); + inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank()); + doMasking = false; + } + // Create masked TransferReadOp. SmallVector<int64_t> inputShape(inputVectorSizes); auto innerTiles = packOp.getStaticInnerTiles(); @@ -1536,7 +1557,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp, for (auto [idx, size] : enumerate(innerTiles)) inputShape[innerDimsPos[idx]] *= size; auto maskedRead = createReadOrMaskedRead(rewriter, loc, packOp.getSource(), - inputShape, padValue); + inputShape, padValue, doMasking); // Create ShapeCastOp. SmallVector<int64_t> destShape(inputVectorSizes); @@ -1763,7 +1784,7 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, /// Returns success if `inputVectorSizes` is a valid masking configuraion for /// given `shape`, i.e., it meets: /// 1. The numbers of elements in both array are equal. -/// 2. `inputVectorSizes` does nos have dynamic dimensions. +/// 2. `inputVectorSizes` does not have dynamic dimensions. /// 3. All the values in `inputVectorSizes` are greater than or equal to /// static sizes in `shape`. static LogicalResult @@ -1881,18 +1902,25 @@ static LogicalResult vectorizeLinalgOpPrecondition( return success(); } -/// TODO: Use a matcher to check for a constant padding value. 
static LogicalResult vectorizePackOpPrecondition(tensor::PackOp packOp, ArrayRef<int64_t> inputVectorSizes) { auto padValue = packOp.getPaddingValue(); - if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) { + Attribute cstAttr; + if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) { LDBG("pad value is not constant: " << packOp << "\n"); return failure(); } - ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape(); - if (failed(isValidMaskedInputVector( + bool satisfyEmptyCond = true; + if (inputVectorSizes.empty()) { + if (!packOp.getDestType().hasStaticShape() || + !packOp.getSourceType().hasStaticShape()) + satisfyEmptyCond = false; + } + + if (!satisfyEmptyCond && + failed(isValidMaskedInputVector( resultTensorShape.take_front(packOp.getSourceRank()), inputVectorSizes))) return failure(); diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp index 10ba508..1f06318 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp @@ -23,12 +23,11 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc, ValueRange independencies) { if (ofr.is<Attribute>()) return ofr; - Value value = ofr.get<Value>(); AffineMap boundMap; ValueDimList mapOperands; if (failed(ValueBoundsConstraintSet::computeIndependentBound( - boundMap, mapOperands, presburger::BoundType::UB, value, - /*dim=*/std::nullopt, independencies, /*closedUB=*/true))) + boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies, + /*closedUB=*/true))) return failure(); return affine::materializeComputedBound(b, loc, boundMap, mapOperands); } diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 35fb174..5d2281c 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1548,90 +1548,41 @@ void printWsloop(OpAsmPrinter &p, Operation *op, Region ®ion, p.printRegion(region, /*printEntryBlockArgs=*/false); } -/// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds -/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps -/// steps := `step` `(`ssa-id-list`)` -ParseResult -parseLoopControl(OpAsmParser &parser, Region ®ion, - SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lowerBound, - SmallVectorImpl<OpAsmParser::UnresolvedOperand> &upperBound, - SmallVectorImpl<OpAsmParser::UnresolvedOperand> &steps, - SmallVectorImpl<Type> &loopVarTypes, UnitAttr &inclusive) { - // Parse an opening `(` followed by induction variables followed by `)` - SmallVector<OpAsmParser::Argument> ivs; - Type loopVarType; - if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) || - parser.parseColonType(loopVarType) || - // Parse loop bounds. - parser.parseEqual() || - parser.parseOperandList(lowerBound, ivs.size(), - OpAsmParser::Delimiter::Paren) || - parser.parseKeyword("to") || - parser.parseOperandList(upperBound, ivs.size(), - OpAsmParser::Delimiter::Paren)) - return failure(); - - if (succeeded(parser.parseOptionalKeyword("inclusive"))) - inclusive = UnitAttr::get(parser.getBuilder().getContext()); - - // Parse step values. - if (parser.parseKeyword("step") || - parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren)) - return failure(); - - // Now parse the body. 
- loopVarTypes = SmallVector<Type>(ivs.size(), loopVarType); - for (auto &iv : ivs) - iv.type = loopVarType; - - return parser.parseRegion(region, ivs); -} - -void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, - ValueRange lowerBound, ValueRange upperBound, - ValueRange steps, TypeRange loopVarTypes, - UnitAttr inclusive) { - auto args = region.front().getArguments(); - p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound - << ") to (" << upperBound << ") "; - if (inclusive) - p << "inclusive "; - p << "step (" << steps << ") "; - p.printRegion(region, /*printEntryBlockArgs=*/false); -} - //===----------------------------------------------------------------------===// // Simd construct [2.9.3.1] //===----------------------------------------------------------------------===// -void SimdLoopOp::build(OpBuilder &builder, OperationState &state, - const SimdLoopClauseOps &clauses) { +void SimdOp::build(OpBuilder &builder, OperationState &state, + const SimdClauseOps &clauses) { MLIRContext *ctx = builder.getContext(); // TODO Store clauses in op: privateVars, reductionByRefAttr, reductionVars, // privatizers, reductionDeclSymbols. - SimdLoopOp::build( - builder, state, clauses.loopLBVar, clauses.loopUBVar, clauses.loopStepVar, - clauses.alignedVars, makeArrayAttr(ctx, clauses.alignmentAttrs), - clauses.ifVar, clauses.nontemporalVars, clauses.orderAttr, - clauses.simdlenAttr, clauses.safelenAttr, clauses.loopInclusiveAttr); + SimdOp::build(builder, state, clauses.alignedVars, + makeArrayAttr(ctx, clauses.alignmentAttrs), clauses.ifVar, + clauses.nontemporalVars, clauses.orderAttr, clauses.simdlenAttr, + clauses.safelenAttr); } -LogicalResult SimdLoopOp::verify() { - if (this->getLowerBound().empty()) { - return emitOpError() << "empty lowerbound for simd loop operation"; - } - if (this->getSimdlen().has_value() && this->getSafelen().has_value() && - this->getSimdlen().value() > this->getSafelen().value()) { +LogicalResult SimdOp::verify() { + if (getSimdlen().has_value() && getSafelen().has_value() && + getSimdlen().value() > getSafelen().value()) return emitOpError() << "simdlen clause and safelen clause are both present, but the " "simdlen value is not less than or equal to safelen value"; - } - if (verifyAlignedClause(*this, this->getAlignmentValues(), - this->getAlignedVars()) + + if (verifyAlignedClause(*this, getAlignmentValues(), getAlignedVars()) .failed()) return failure(); - if (verifyNontemporalClause(*this, this->getNontemporalVars()).failed()) + + if (verifyNontemporalClause(*this, getNontemporalVars()).failed()) return failure(); + + if (!isWrapper()) + return emitOpError() << "must be a loop wrapper"; + + if (getNestedWrapper()) + return emitOpError() << "must wrap an 'omp.loop_nest' directly"; + return success(); } @@ -1656,6 +1607,17 @@ LogicalResult DistributeOp::verify() { return emitError( "expected equal sizes for allocate and allocator variables"); + if (!isWrapper()) + return emitOpError() << "must be a loop wrapper"; + + if (LoopWrapperInterface nested = getNestedWrapper()) { + // Check for the allowed leaf constructs that may appear in a composite + // construct directly after DISTRIBUTE. 
+ if (!isa<ParallelOp, SimdOp>(nested)) + return emitError() << "only supported nested wrappers are 'omp.parallel' " + "and 'omp.simd'"; + } + return success(); } @@ -1818,9 +1780,8 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state, MLIRContext *ctx = builder.getContext(); // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers. TaskloopOp::build( - builder, state, clauses.loopLBVar, clauses.loopUBVar, clauses.loopStepVar, - clauses.loopInclusiveAttr, clauses.ifVar, clauses.finalVar, - clauses.untiedAttr, clauses.mergeableAttr, clauses.inReductionVars, + builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr, + clauses.mergeableAttr, clauses.inReductionVars, makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.reductionVars, makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.priorityVar, clauses.allocateVars, clauses.allocatorVars, clauses.grainsizeVar, @@ -1859,6 +1820,16 @@ LogicalResult TaskloopOp::verify() { "the grainsize clause and num_tasks clause are mutually exclusive and " "may not appear on the same taskloop directive"); } + + if (!isWrapper()) + return emitOpError() << "must be a loop wrapper"; + + if (LoopWrapperInterface nested = getNestedWrapper()) { + // Check for the allowed leaf constructs that may appear in a composite + // construct directly after TASKLOOP. + if (!isa<SimdOp>(nested)) + return emitError() << "only supported nested wrapper is 'omp.simd'"; + } return success(); } @@ -1936,9 +1907,27 @@ LogicalResult LoopNestOp::verify() { << "range argument type does not match corresponding IV type"; } + auto wrapper = + llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()); + + if (!wrapper || !wrapper.isWrapper()) + return emitOpError() << "expects parent op to be a valid loop wrapper"; + return success(); } +void LoopNestOp::gatherWrappers( + SmallVectorImpl<LoopWrapperInterface> &wrappers) { + Operation *parent = (*this)->getParentOp(); + while (auto wrapper = + llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) { + if (!wrapper.isWrapper()) + break; + wrappers.push_back(wrapper); + parent = parent->getParentOp(); + } +} + //===----------------------------------------------------------------------===// // WsloopOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Polynomial/CMakeLists.txt b/mlir/lib/Dialect/Polynomial/CMakeLists.txt new file mode 100644 index 0000000..f33061b2 --- /dev/null +++ b/mlir/lib/Dialect/Polynomial/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt b/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt new file mode 100644 index 0000000..7f5b325 --- /dev/null +++ b/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_dialect_library(MLIRPolynomialDialect + Polynomial.cpp + PolynomialAttributes.cpp + PolynomialDialect.cpp + PolynomialOps.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Polynomial + + DEPENDS + MLIRPolynomialIncGen + MLIRPolynomialAttributesIncGen + MLIRBuiltinAttributesIncGen + + LINK_LIBS PUBLIC + MLIRSupport + MLIRDialect + MLIRIR + ) diff --git a/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp new file mode 100644 index 0000000..5916ffb --- /dev/null +++ b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp @@ -0,0 +1,96 @@ +//===- Polynomial.cpp - MLIR storage type for static Polynomial -*- C++ -*-===// +// +// Part of the LLVM Project, under the 
Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Polynomial/IR/Polynomial.h" + +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace polynomial { + +FailureOr<Polynomial> Polynomial::fromMonomials(ArrayRef<Monomial> monomials) { + // A polynomial's terms are canonically stored in order of increasing degree. + auto monomialsCopy = llvm::SmallVector<Monomial>(monomials); + std::sort(monomialsCopy.begin(), monomialsCopy.end()); + + // Ensure non-unique exponents are not present. Since we sorted the list by + // exponent, a linear scan of adjancent monomials suffices. + if (std::adjacent_find(monomialsCopy.begin(), monomialsCopy.end(), + [](const Monomial &lhs, const Monomial &rhs) { + return lhs.exponent == rhs.exponent; + }) != monomialsCopy.end()) { + return failure(); + } + + return Polynomial(monomialsCopy); +} + +Polynomial Polynomial::fromCoefficients(ArrayRef<int64_t> coeffs) { + llvm::SmallVector<Monomial> monomials; + auto size = coeffs.size(); + monomials.reserve(size); + for (size_t i = 0; i < size; i++) { + monomials.emplace_back(coeffs[i], i); + } + auto result = Polynomial::fromMonomials(monomials); + // Construction guarantees unique exponents, so the failure mode of + // fromMonomials can be bypassed. + assert(succeeded(result)); + return result.value(); +} + +void Polynomial::print(raw_ostream &os, ::llvm::StringRef separator, + ::llvm::StringRef exponentiation) const { + bool first = true; + for (const Monomial &term : terms) { + if (first) { + first = false; + } else { + os << separator; + } + std::string coeffToPrint; + if (term.coefficient == 1 && term.exponent.uge(1)) { + coeffToPrint = ""; + } else { + llvm::SmallString<16> coeffString; + term.coefficient.toStringSigned(coeffString); + coeffToPrint = coeffString.str(); + } + + if (term.exponent == 0) { + os << coeffToPrint; + } else if (term.exponent == 1) { + os << coeffToPrint << "x"; + } else { + llvm::SmallString<16> expString; + term.exponent.toStringSigned(expString); + os << coeffToPrint << "x" << exponentiation << expString; + } + } +} + +void Polynomial::print(raw_ostream &os) const { print(os, " + ", "**"); } + +std::string Polynomial::toIdentifier() const { + std::string result; + llvm::raw_string_ostream os(result); + print(os, "_", ""); + return os.str(); +} + +unsigned Polynomial::getDegree() const { + return terms.back().exponent.getZExtValue(); +} + +} // namespace polynomial +} // namespace mlir diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp new file mode 100644 index 0000000..ee09c73 --- /dev/null +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp @@ -0,0 +1,213 @@ +//===- PolynomialAttributes.cpp - Polynomial dialect attrs ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" + +#include "mlir/Dialect/Polynomial/IR/Polynomial.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" + +namespace mlir { +namespace polynomial { + +void PolynomialAttr::print(AsmPrinter &p) const { + p << '<'; + p << getPolynomial(); + p << '>'; +} + +/// Try to parse a monomial. If successful, populate the fields of the outparam +/// `monomial` with the results, and the `variable` outparam with the parsed +/// variable name. Sets shouldParseMore to true if the monomial is followed by +/// a '+'. +ParseResult parseMonomial(AsmParser &parser, Monomial &monomial, + llvm::StringRef &variable, bool &isConstantTerm, + bool &shouldParseMore) { + APInt parsedCoeff(apintBitWidth, 1); + auto parsedCoeffResult = parser.parseOptionalInteger(parsedCoeff); + monomial.coefficient = parsedCoeff; + + isConstantTerm = false; + shouldParseMore = false; + + // A + indicates it's a constant term with more to go, as in `1 + x`. + if (succeeded(parser.parseOptionalPlus())) { + // If no coefficient was parsed, and there's a +, then it's effectively + // parsing an empty string. + if (!parsedCoeffResult.has_value()) { + return failure(); + } + monomial.exponent = APInt(apintBitWidth, 0); + isConstantTerm = true; + shouldParseMore = true; + return success(); + } + + // A monomial can be a trailing constant term, as in `x + 1`. + if (failed(parser.parseOptionalKeyword(&variable))) { + // If neither a coefficient nor a variable was found, then it's effectively + // parsing an empty string. + if (!parsedCoeffResult.has_value()) { + return failure(); + } + + monomial.exponent = APInt(apintBitWidth, 0); + isConstantTerm = true; + return success(); + } + + // Parse exponentiation symbol as `**`. We can't use caret because it's + // reserved for basic block identifiers If no star is present, it's treated + // as a polynomial with exponent 1. + if (succeeded(parser.parseOptionalStar())) { + // If there's one * there must be two. + if (failed(parser.parseStar())) { + return failure(); + } + + // If there's a **, then the integer exponent is required. 
+ APInt parsedExponent(apintBitWidth, 0); + if (failed(parser.parseInteger(parsedExponent))) { + parser.emitError(parser.getCurrentLocation(), + "found invalid integer exponent"); + return failure(); + } + + monomial.exponent = parsedExponent; + } else { + monomial.exponent = APInt(apintBitWidth, 1); + } + + if (succeeded(parser.parseOptionalPlus())) { + shouldParseMore = true; + } + return success(); +} + +Attribute PolynomialAttr::parse(AsmParser &parser, Type type) { + if (failed(parser.parseLess())) + return {}; + + llvm::SmallVector<Monomial> monomials; + llvm::StringSet<> variables; + + while (true) { + Monomial parsedMonomial; + llvm::StringRef parsedVariableRef; + bool isConstantTerm; + bool shouldParseMore; + if (failed(parseMonomial(parser, parsedMonomial, parsedVariableRef, + isConstantTerm, shouldParseMore))) { + parser.emitError(parser.getCurrentLocation(), "expected a monomial"); + return {}; + } + + if (!isConstantTerm) { + std::string parsedVariable = parsedVariableRef.str(); + variables.insert(parsedVariable); + } + monomials.push_back(parsedMonomial); + + if (shouldParseMore) + continue; + + if (succeeded(parser.parseOptionalGreater())) { + break; + } + parser.emitError( + parser.getCurrentLocation(), + "expected + and more monomials, or > to end polynomial attribute"); + return {}; + } + + if (variables.size() > 1) { + std::string vars = llvm::join(variables.keys(), ", "); + parser.emitError( + parser.getCurrentLocation(), + "polynomials must have one indeterminate, but there were multiple: " + + vars); + } + + auto result = Polynomial::fromMonomials(monomials); + if (failed(result)) { + parser.emitError(parser.getCurrentLocation()) + << "parsed polynomial must have unique exponents among monomials"; + return {}; + } + return PolynomialAttr::get(parser.getContext(), result.value()); +} + +void RingAttr::print(AsmPrinter &p) const { + p << "#polynomial.ring<coefficientType=" << getCoefficientType() + << ", coefficientModulus=" << getCoefficientModulus() + << ", polynomialModulus=" << getPolynomialModulus() << '>'; +} + +Attribute RingAttr::parse(AsmParser &parser, Type type) { + if (failed(parser.parseLess())) + return {}; + + if (failed(parser.parseKeyword("coefficientType"))) + return {}; + + if (failed(parser.parseEqual())) + return {}; + + Type ty; + if (failed(parser.parseType(ty))) + return {}; + + if (failed(parser.parseComma())) + return {}; + + IntegerAttr coefficientModulusAttr = nullptr; + if (succeeded(parser.parseKeyword("coefficientModulus"))) { + if (failed(parser.parseEqual())) + return {}; + + IntegerType iType = ty.dyn_cast<IntegerType>(); + if (!iType) { + parser.emitError(parser.getCurrentLocation(), + "coefficientType must specify an integer type"); + return {}; + } + APInt coefficientModulus(iType.getWidth(), 0); + auto result = parser.parseInteger(coefficientModulus); + if (failed(result)) { + parser.emitError(parser.getCurrentLocation(), + "invalid coefficient modulus"); + return {}; + } + coefficientModulusAttr = IntegerAttr::get(iType, coefficientModulus); + + if (failed(parser.parseComma())) + return {}; + } + + PolynomialAttr polyAttr = nullptr; + if (succeeded(parser.parseKeyword("polynomialModulus"))) { + if (failed(parser.parseEqual())) + return {}; + + PolynomialAttr attr; + if (failed(parser.parseAttribute<PolynomialAttr>(attr))) + return {}; + polyAttr = attr; + } + + if (failed(parser.parseGreater())) + return {}; + + return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr, + polyAttr); +} + +} // namespace polynomial +} // 
namespace mlir diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp new file mode 100644 index 0000000..a672a59 --- /dev/null +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp @@ -0,0 +1,41 @@ +//===- PolynomialDialect.cpp - Polynomial dialect ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Polynomial/IR/Polynomial.h" + +#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" +#include "mlir/Dialect/Polynomial/IR/PolynomialOps.h" +#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::polynomial; + +#include "mlir/Dialect/Polynomial/IR/PolynomialDialect.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.cpp.inc" +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.cpp.inc" +#define GET_OP_CLASSES +#include "mlir/Dialect/Polynomial/IR/Polynomial.cpp.inc" + +void PolynomialDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.cpp.inc" + >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/Polynomial/IR/Polynomial.cpp.inc" + >(); +} diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp new file mode 100644 index 0000000..96c59a2 --- /dev/null +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp @@ -0,0 +1,15 @@ +//===- PolynomialOps.cpp - Polynomial dialect ops ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Polynomial/IR/Polynomial.h" + +using namespace mlir; +using namespace mlir::polynomial; + +#define GET_OP_CLASSES +#include "mlir/Dialect/Polynomial/IR/Polynomial.cpp.inc" diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp index 087ffc4..17a1c01 100644 --- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp @@ -61,12 +61,13 @@ struct ForOpInterface // An EQ constraint can be added if the yielded value (dimension size) // equals the corresponding block argument (dimension size). 
if (cstr.populateAndCompare( - yieldedValue, dim, ValueBoundsConstraintSet::ComparisonOperator::EQ, - iterArg, dim)) { + /*lhs=*/{yieldedValue, dim}, + ValueBoundsConstraintSet::ComparisonOperator::EQ, + /*rhs=*/{iterArg, dim})) { if (dim.has_value()) { cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim); } else { - cstr.bound(value) == initArg; + cstr.bound(value) == cstr.getExpr(initArg); } } } @@ -113,8 +114,9 @@ struct IfOpInterface // * result <= elseValue // * result >= thenValue if (cstr.populateAndCompare( - thenValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE, - elseValue, dim)) { + /*lhs=*/{thenValue, dim}, + ValueBoundsConstraintSet::ComparisonOperator::LE, + /*rhs=*/{elseValue, dim})) { if (dim) { cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim); cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim); @@ -127,8 +129,9 @@ struct IfOpInterface // * result <= thenValue // * result >= elseValue if (cstr.populateAndCompare( - elseValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE, - thenValue, dim)) { + /*lhs=*/{elseValue, dim}, + ValueBoundsConstraintSet::ComparisonOperator::LE, + /*rhs=*/{thenValue, dim})) { if (dim) { cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim); cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim); diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt index e549420..a2925ae 100644 --- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt @@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRSCFTransforms StructuralTypeConversions.cpp TileUsingInterface.cpp WrapInZeroTripCheck.cpp + UpliftWhileToFor.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp new file mode 100644 index 0000000..7b4024b --- /dev/null +++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp @@ -0,0 +1,214 @@ +//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Transforms SCF.WhileOp's into SCF.ForOp's. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SCF/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" + +using namespace mlir; + +namespace { +struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::WhileOp loop, + PatternRewriter &rewriter) const override { + return upliftWhileToForLoop(rewriter, loop); + } +}; +} // namespace + +FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, + scf::WhileOp loop) { + Block *beforeBody = loop.getBeforeBody(); + if (!llvm::hasSingleElement(beforeBody->without_terminator())) + return rewriter.notifyMatchFailure(loop, "Loop body must have single op"); + + auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front()); + if (!cmp) + return rewriter.notifyMatchFailure(loop, + "Loop body must have single cmp op"); + + scf::ConditionOp beforeTerm = loop.getConditionOp(); + if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult()) + return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { + diag << "Expected single condition use: " << *cmp; + }); + + // All `before` block args must be directly forwarded to ConditionOp. + // They will be converted to `scf.for` `iter_vars` except induction var. + if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs()) + return rewriter.notifyMatchFailure(loop, "Invalid args order"); + + using Pred = arith::CmpIPredicate; + Pred predicate = cmp.getPredicate(); + if (predicate != Pred::slt && predicate != Pred::sgt) + return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { + diag << "Expected 'slt' or 'sgt' predicate: " << *cmp; + }); + + BlockArgument inductionVar; + Value ub; + DominanceInfo dom; + + // Check if cmp has a suitable form. One of the arguments must be a `before` + // block arg, other must be defined outside `scf.while` and will be treated + // as upper bound. + for (bool reverse : {false, true}) { + auto expectedPred = reverse ? Pred::sgt : Pred::slt; + if (cmp.getPredicate() != expectedPred) + continue; + + auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs(); + auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs(); + + auto blockArg = dyn_cast<BlockArgument>(arg1); + if (!blockArg || blockArg.getOwner() != beforeBody) + continue; + + if (!dom.properlyDominates(arg2, loop)) + continue; + + inductionVar = blockArg; + ub = arg2; + break; + } + + if (!inductionVar) + return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { + diag << "Unrecognized cmp form: " << *cmp; + }); + + // inductionVar must have 2 uses: one is in `cmp` and other is `condition` + // arg. + if (!llvm::hasNItems(inductionVar.getUses(), 2)) + return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { + diag << "Unrecognized induction var: " << inductionVar; + }); + + Block *afterBody = loop.getAfterBody(); + scf::YieldOp afterTerm = loop.getYieldOp(); + unsigned argNumber = inductionVar.getArgNumber(); + Value afterTermIndArg = afterTerm.getResults()[argNumber]; + + Value inductionVarAfter = afterBody->getArgument(argNumber); + + // Find suitable `addi` op inside `after` block, one of the args must be an + // Induction var passed from `before` block and second arg must be defined + // outside of the loop and will be considered step value. 
+ // TODO: Add `subi` support? + auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>(); + if (!addOp) + return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op"); + + Value step; + if (addOp.getLhs() == inductionVarAfter) { + step = addOp.getRhs(); + } else if (addOp.getRhs() == inductionVarAfter) { + step = addOp.getLhs(); + } + + if (!step || !dom.properlyDominates(step, loop)) + return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form"); + + Value lb = loop.getInits()[argNumber]; + + assert(lb.getType().isIntOrIndex()); + assert(lb.getType() == ub.getType()); + assert(lb.getType() == step.getType()); + + llvm::SmallVector<Value> newArgs; + + // Populate inits for new `scf.for`, skip induction var. + newArgs.reserve(loop.getInits().size()); + for (auto &&[i, init] : llvm::enumerate(loop.getInits())) { + if (i == argNumber) + continue; + + newArgs.emplace_back(init); + } + + Location loc = loop.getLoc(); + + // With `builder == nullptr`, ForOp::build will try to insert terminator at + // the end of newly created block and we don't want it. Provide empty + // dummy builder instead. + auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {}; + auto newLoop = + rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder); + + Block *newBody = newLoop.getBody(); + + // Populate block args for `scf.for` body, move induction var to the front. + newArgs.clear(); + ValueRange newBodyArgs = newBody->getArguments(); + for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) { + if (i < argNumber) { + newArgs.emplace_back(newBodyArgs[i + 1]); + } else if (i == argNumber) { + newArgs.emplace_back(newBodyArgs.front()); + } else { + newArgs.emplace_back(newBodyArgs[i]); + } + } + + rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(), + newArgs); + + auto term = cast<scf::YieldOp>(newBody->getTerminator()); + + // Populate new yield args, skipping the induction var. + newArgs.clear(); + for (auto &&[i, arg] : llvm::enumerate(term.getResults())) { + if (i == argNumber) + continue; + + newArgs.emplace_back(arg); + } + + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(term); + rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs); + + // Compute induction var value after loop execution. + rewriter.setInsertionPointAfter(newLoop); + Value one; + if (isa<IndexType>(step.getType())) { + one = rewriter.create<arith::ConstantIndexOp>(loc, 1); + } else { + one = rewriter.create<arith::ConstantIntOp>(loc, 1, step.getType()); + } + + Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one); + Value len = rewriter.create<arith::SubIOp>(loc, ub, lb); + len = rewriter.create<arith::AddIOp>(loc, len, stepDec); + len = rewriter.create<arith::DivSIOp>(loc, len, step); + len = rewriter.create<arith::SubIOp>(loc, len, one); + Value res = rewriter.create<arith::MulIOp>(loc, len, step); + res = rewriter.create<arith::AddIOp>(loc, lb, res); + + // Reconstruct `scf.while` results, inserting final induction var value + // into proper place. 
+ newArgs.clear(); + llvm::append_range(newArgs, newLoop.getResults()); + newArgs.insert(newArgs.begin() + argNumber, res); + rewriter.replaceOp(loop, newArgs); + return newLoop; +} + +void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) { + patterns.add<UpliftWhileOp>(patterns.getContext()); +} diff --git a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp index 5b7c0a5..bbc318e 100644 --- a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp @@ -120,17 +120,16 @@ bool spirv::needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr) { StringRef spirv::getEntryPointABIAttrName() { return "spirv.entry_point_abi"; } -spirv::EntryPointABIAttr -spirv::getEntryPointABIAttr(MLIRContext *context, - ArrayRef<int32_t> workgroupSize, - std::optional<int> subgroupSize) { +spirv::EntryPointABIAttr spirv::getEntryPointABIAttr( + MLIRContext *context, ArrayRef<int32_t> workgroupSize, + std::optional<int> subgroupSize, std::optional<int> targetWidth) { DenseI32ArrayAttr workgroupSizeAttr; if (!workgroupSize.empty()) { assert(workgroupSize.size() == 3); workgroupSizeAttr = DenseI32ArrayAttr::get(context, workgroupSize); } - return spirv::EntryPointABIAttr::get(context, workgroupSizeAttr, - subgroupSize); + return spirv::EntryPointABIAttr::get(context, workgroupSizeAttr, subgroupSize, + targetWidth); } spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) { diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 6150b5e..2024a2e 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -157,7 +157,7 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, // Erase workgroup size. entryPointAttr = spirv::EntryPointABIAttr::get( entryPointAttr.getContext(), DenseI32ArrayAttr(), - entryPointAttr.getSubgroupSize()); + entryPointAttr.getSubgroupSize(), entryPointAttr.getTargetWidth()); } } if (std::optional<int> subgroupSize = entryPointAttr.getSubgroupSize()) { @@ -170,10 +170,24 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, // Erase subgroup size. entryPointAttr = spirv::EntryPointABIAttr::get( entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(), - std::nullopt); + std::nullopt, entryPointAttr.getTargetWidth()); } } - if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize()) + if (std::optional<int> targetWidth = entryPointAttr.getTargetWidth()) { + std::optional<ArrayRef<spirv::Capability>> caps = + spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve); + if (!caps || targetEnv.allows(*caps)) { + builder.create<spirv::ExecutionModeOp>( + funcOp.getLoc(), funcOp, + spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth); + // Erase target width. 
+ entryPointAttr = spirv::EntryPointABIAttr::get( + entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(), + entryPointAttr.getSubgroupSize(), std::nullopt); + } + } + if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize() || + entryPointAttr.getTargetWidth()) funcOp->setAttr(entryPointAttrName, entryPointAttr); else funcOp->removeAttr(entryPointAttrName); diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index e905839..516b094 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -30,6 +30,14 @@ #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc" +// Forward declarations, following custom print/parsing methods are referenced +// by the generated code for SparseTensorTypes.td. +static mlir::ParseResult parseLevelRange(mlir::AsmParser &, + mlir::sparse_tensor::Level &, + mlir::sparse_tensor::Level &); +static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, + mlir::sparse_tensor::Level); + #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc" @@ -1953,6 +1961,108 @@ LogicalResult SortOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// Sparse Tensor Iteration Operations. +//===----------------------------------------------------------------------===// + +IterSpaceType IteratorType::getIterSpaceType() const { + return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(), + getHiLvl()); +} + +IteratorType IterSpaceType::getIteratorType() const { + return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl()); +} + +/// Parses a level range in the form "$lo `to` $hi" +/// or simply "$lo" if $hi - $lo = 1 +static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo, + Level &lvlHi) { + if (parser.parseInteger(lvlLo)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("to"))) { + if (parser.parseInteger(lvlHi)) + return failure(); + } else { + lvlHi = lvlLo + 1; + } + + if (lvlHi <= lvlLo) + parser.emitError(parser.getNameLoc(), + "expect larger level upper bound than lower bound"); + + return success(); +} + +/// Parses a level range in the form "$lo `to` $hi" +/// or simply "$lo" if $hi - $lo = 1 +static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr, + IntegerAttr &lvlHiAttr) { + Level lvlLo, lvlHi; + if (parseLevelRange(parser, lvlLo, lvlHi)) + return failure(); + + lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo); + lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi); + return success(); +} + +/// Prints a level range in the form "$lo `to` $hi" +/// or simply "$lo" if $hi - $lo = 1 +static void printLevelRange(AsmPrinter &p, Level lo, Level hi) { + + if (lo + 1 == hi) + p << lo; + else + p << lo << " to " << hi; +} + +/// Prints a level range in the form "$lo `to` $hi" +/// or simply "$lo" if $hi - $lo = 1 +static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo, + IntegerAttr lvlHi) { + unsigned lo = lvlLo.getValue().getZExtValue(); + unsigned hi = lvlHi.getValue().getZExtValue(); + printLevelRange(p, lo, hi); +} + +LogicalResult ExtractIterSpaceOp::inferReturnTypes( + MLIRContext *ctx, std::optional<Location> loc, ValueRange ops, + DictionaryAttr attr, 
OpaqueProperties prop, RegionRange region, + SmallVectorImpl<mlir::Type> &ret) { + + ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region); + SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); + ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(), + adaptor.getHiLvl())); + return success(); +} + +LogicalResult ExtractIterSpaceOp::verify() { + if (getLoLvl() >= getHiLvl()) + return emitOpError("expected smaller level low than level high"); + + TypedValue<IteratorType> pIter = getParentIter(); + if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) { + return emitOpError( + "parent iterator should be specified iff level lower bound equals 0"); + } + + if (pIter) { + IterSpaceType spaceTp = getResultSpace().getType(); + if (pIter.getType().getEncoding() != spaceTp.getEncoding()) + return emitOpError( + "mismatch in parent iterator encoding and iteration space encoding."); + + if (spaceTp.getLoLvl() != pIter.getType().getHiLvl()) + return emitOpError("parent iterator should be used to extract an " + "iteration space from a consecutive level."); + } + + return success(); +} + /// Materialize a single constant operation from a given attribute value with /// the desired resultant type. Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp index f497be6..3a89720 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp @@ -24,6 +24,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp index 67080d8..d25efcf 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -289,8 +289,7 @@ static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp, info.isAlignedToInnerTileSize = false; FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound( - presburger::BoundType::UB, - getValueOrCreateConstantIndexOp(b, loc, tileSize), /*dim=*/std::nullopt, + presburger::BoundType::UB, tileSize, /*stopCondition=*/nullptr, /*closedUB=*/true); std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize); if (!failed(cstSize) && cstInnerSize) { diff --git a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp index 72173086..a89ce20 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp @@ -28,7 +28,8 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc, ValueDimList mapOperands; if (failed(ValueBoundsConstraintSet::computeIndependentBound( boundMap, mapOperands, presburger::BoundType::UB, value, - /*dim=*/std::nullopt, independencies, /*closedUB=*/true))) + independencies, + /*closedUB=*/true))) return failure(); return mlir::affine::materializeComputedBound(b, loc, boundMap, mapOperands); } diff --git 
a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp index 2dd91e2..15381ec 100644 --- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp @@ -154,7 +154,7 @@ bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) { continue; } FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual( - op.getSource(), op.getResult(), srcDim, resultDim); + {op.getSource(), srcDim}, {op.getResult(), resultDim}); if (failed(equalDimSize) || !*equalDimSize) return false; ++srcDim; @@ -178,7 +178,7 @@ bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) { continue; } FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual( - op.getSource(), op.getResult(), dim, resultDim); + {op.getSource(), dim}, {op.getResult(), resultDim}); if (failed(equalDimSize) || !*equalDimSize) return false; ++resultDim; diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index dc19022..53f958c 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -396,6 +396,13 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne( static_cast<RewriterBase::Listener *>(rewriter.getListener()); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1) + ? GreedyRewriteConfig::kNoLimit + : getMaxIterations(); + config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1) + ? GreedyRewriteConfig::kNoLimit + : getMaxNumRewrites(); + // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE // was requested, apply the greedy pattern rewrite only once. (The greedy // pattern rewrite driver already iterates to a fixpoint internally.) 
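The transform::ApplyPatternsOp hunk above maps the op's max_iterations and max_num_rewrites values onto GreedyRewriteConfig, treating a value of -1 as "no limit". A minimal sketch of that sentinel mapping follows (illustrative only; the helper name is invented and is not part of the patch):

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <cstdint>

// Hypothetical helper: a user-facing limit of -1 (all bits set) selects
// GreedyRewriteConfig::kNoLimit; any other value is forwarded unchanged.
static int64_t toGreedyLimit(uint64_t userLimit) {
  return userLimit == static_cast<uint64_t>(-1)
             ? mlir::GreedyRewriteConfig::kNoLimit
             : static_cast<int64_t>(userLimit);
}

// Usage corresponding to the hunk above (sketch):
//   config.maxIterations  = toGreedyLimit(getMaxIterations());
//   config.maxNumRewrites = toGreedyLimit(getMaxNumRewrites());
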
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 0b3f4b9..24719fe 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -32,6 +32,17 @@ void XeGPUDialect::initialize() { //===----------------------------------------------------------------------===// // XeGPU_TensorDescAttr //===----------------------------------------------------------------------===// +TensorDescAttr TensorDescAttr::get(mlir::MLIRContext *context, + xegpu::MemoryScope memory_scope, + int array_length, bool boundary_check, + bool scattered) { + auto scopeAttr = MemoryScopeAttr::get(context, memory_scope); + auto lengthAttr = + IntegerAttr::get(IntegerType::get(context, 64), array_length); + auto boundaryAttr = BoolAttr::get(context, boundary_check); + auto scatteredAttr = BoolAttr::get(context, scattered); + return Base::get(context, scopeAttr, lengthAttr, boundaryAttr, scatteredAttr); +} //===----------------------------------------------------------------------===// // XeGPU_TensorDescType @@ -96,6 +107,16 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const { printer << ">"; } +TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape, + mlir::Type elementType, bool scattered, + int array_length, MemoryScope memory_scope, + bool boundary_check) { + auto context = elementType.getContext(); + auto attr = TensorDescAttr::get(context, memory_scope, array_length, + boundary_check, scattered); + return Base::get(context, shape, elementType, attr); +} + } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 02106f2..530c50e 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -9,6 +9,9 @@ #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/TypeUtilities.h" + +#include "llvm/Support/Debug.h" #define DEBUG_TYPE "xegpu" @@ -16,8 +19,8 @@ namespace mlir { namespace xegpu { static void transpose(llvm::ArrayRef<int64_t> trans, - std::vector<int64_t> &shape) { - std::vector<int64_t> old = shape; + SmallVector<int64_t> &shape) { + SmallVector<int64_t> old = shape; for (size_t i = 0; i < trans.size(); i++) shape[i] = old[trans[i]]; } @@ -38,6 +41,38 @@ static std::string makeString(T array, bool breakline = false) { return buf; } +static SmallVector<int64_t> getShapeOf(Type type) { + SmallVector<int64_t> shape; + if (auto ty = llvm::dyn_cast<ShapedType>(type)) + shape = SmallVector<int64_t>(ty.getShape()); + else + shape.push_back(1); + return shape; +} + +static int64_t getRankOf(Value val) { + auto type = val.getType(); + if (auto ty = llvm::dyn_cast<ShapedType>(type)) + return ty.getRank(); + return 0; +} + +static bool isReadHintOrNone(const CachePolicyAttr &attr) { + if (!attr) + return true; + auto kind = attr.getValue(); + return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED || + kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE; +} + +static bool isWriteHintOrNone(const CachePolicyAttr &attr) { + if (!attr) + return true; + auto kind = attr.getValue(); + return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED || + kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH; +} + //===----------------------------------------------------------------------===// // XeGPU_CreateNdDescOp 
//===----------------------------------------------------------------------===// @@ -114,6 +149,29 @@ LogicalResult CreateNdDescOp::verify() { return emitOpError("TensorDesc should have the same element " "type with the source if it is a memref.\n"); + if (getType().getScattered()) + return emitOpError("Expects a non-scattered TensorDesc.\n"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// XeGPU_PrefetchNdOp +//===----------------------------------------------------------------------===// +LogicalResult PrefetchNdOp::verify() { + auto tdescTy = getTensorDescType(); + if (tdescTy.getScattered()) + return emitOpError("Expects a non-scattered TensorDesc.\n"); + + if (!isReadHintOrNone(getL1HintAttr())) + return emitOpError("invlid l1_hint: ") << getL1HintAttr(); + + if (!isReadHintOrNone(getL2HintAttr())) + return emitOpError("invlid l2_hint: ") << getL2HintAttr(); + + if (!isReadHintOrNone(getL3HintAttr())) + return emitOpError("invlid l3_hint: ") << getL3HintAttr(); + return success(); } @@ -125,22 +183,26 @@ LogicalResult LoadNdOp::verify() { auto valueTy = getType(); if (tdescTy.getRank() != 2) - return emitOpError( - "The TensorDesc for LoadNdOp should be a 2D TensorDesc."); + return emitOpError("Expecting a 2D TensorDesc.\n"); + + if (tdescTy.getScattered()) + return emitOpError("Expects a non-scattered TensorDesc.\n"); if (!valueTy) return emitOpError("Invalid result, it should be a VectorType.\n"); - auto tdescElemTy = tdescTy.getElementType(); - auto valueElemTy = valueTy.getElementType(); + if (!isReadHintOrNone(getL1HintAttr())) + return emitOpError("invlid l1_hint: ") << getL1HintAttr(); - if (tdescElemTy != valueElemTy) - return emitOpError( - "Value should have the same element type as TensorDesc."); + if (!isReadHintOrNone(getL2HintAttr())) + return emitOpError("invlid l2_hint: ") << getL2HintAttr(); + + if (!isReadHintOrNone(getL3HintAttr())) + return emitOpError("invlid l3_hint: ") << getL3HintAttr(); auto array_len = tdescTy.getArrayLength(); - auto tdescShape = tdescTy.getShape().vec(); - auto valueShape = valueTy.getShape().vec(); + auto tdescShape = getShapeOf(tdescTy); + auto valueShape = getShapeOf(valueTy); if (getTranspose()) { auto trans = getTranspose().value(); @@ -174,26 +236,174 @@ LogicalResult LoadNdOp::verify() { // XeGPU_StoreNdOp //===----------------------------------------------------------------------===// LogicalResult StoreNdOp::verify() { - auto dstTy = getTensorDesc().getType(); // Tile - auto valTy = getValue().getType().cast<VectorType>(); // Vector + auto dstTy = getTensorDescType(); // Tile + auto valTy = getValueType(); // Vector if (dstTy.getRank() != 2) - return emitOpError("Expecting a 2D TensorDesc shape.\n"); + return emitOpError("Expecting a 2D TensorDesc.\n"); + + if (dstTy.getScattered()) + return emitOpError("Expects a non-scattered TensorDesc.\n"); if (!valTy) return emitOpError("Exepcting a VectorType result.\n"); - auto dstElemTy = dstTy.getElementType(); - auto valElemTy = valTy.getElementType(); + if (!isWriteHintOrNone(getL1HintAttr())) + return emitOpError("invlid l1_hint: ") << getL1HintAttr(); + + if (!isWriteHintOrNone(getL2HintAttr())) + return emitOpError("invlid l2_hint: ") << getL2HintAttr(); + + if (!isWriteHintOrNone(getL3HintAttr())) + return emitOpError("invlid l3_hint: ") << getL3HintAttr(); + + return success(); +} - if (dstElemTy != valElemTy) { - return emitOpError() << "The element type of the value should " - "match the elementtype of the 
TensorDesc.\n"; +//===----------------------------------------------------------------------===// +// XeGPU_UpdateNDOffsetOp +//===----------------------------------------------------------------------===// +LogicalResult UpdateNdOffsetOp::verify() { + auto ty = getTensorDescType(); + if (ty.getScattered()) + return emitOpError("Expects a non-scattered TensorDesc.\n"); + + // number of offsets specified must match the rank of the tensor descriptor + if (ty.getRank() != (int64_t)getNumOffsets()) { + return emitOpError("Invalid number of offsets."); } + return success(); +} + +//===----------------------------------------------------------------------===// +// XeGPU_CreateDescOp +//===----------------------------------------------------------------------===// +void CreateDescOp::build(OpBuilder &builder, OperationState &state, + TensorDescType TensorDesc, Value source, + llvm::ArrayRef<OpFoldResult> offsets, + uint32_t chunk_size) { + llvm::SmallVector<int64_t> staticOffsets; + llvm::SmallVector<Value> dynamicOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + build(builder, state, TensorDesc, source, dynamicOffsets, staticOffsets, + chunk_size); +} + +LogicalResult CreateDescOp::verify() { + auto tdescTy = getTensorDescType(); + auto chunkSize = getChunkSize(); + + if (getRankOf(getSource()) > 1) + return emitOpError( + "Expecting the source is a 1D memref or pointer (uint64_t)."); + + if (!tdescTy.getScattered()) + return emitOpError("Expects a scattered TensorDesc.\n"); + + SmallVector<int64_t> shape({(int64_t)getNumOffsets()}); + if (chunkSize != 1) + shape.push_back(chunkSize); + + auto tdescShape = getShapeOf(tdescTy); + if (shape != tdescShape) + return emitOpError("Incorrect TensorDesc shape. ") + << "Expected is " << makeString(shape) << "\n"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// XeGPU_PrefetchOp +//===----------------------------------------------------------------------===// +LogicalResult PrefetchOp::verify() { + auto tdescTy = getTensorDescType(); + if (!tdescTy.getScattered()) + return emitOpError("Expects a scattered TensorDesc.\n"); + + if (!isReadHintOrNone(getL1HintAttr())) + return emitOpError("invlid l1_hint: ") << getL1HintAttr(); + + if (!isReadHintOrNone(getL2HintAttr())) + return emitOpError("invlid l2_hint: ") << getL2HintAttr(); + + if (!isReadHintOrNone(getL3HintAttr())) + return emitOpError("invlid l3_hint: ") << getL3HintAttr(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// XeGPU_LoadGatherOp +//===----------------------------------------------------------------------===// +LogicalResult LoadGatherOp::verify() { + auto tdescTy = getTensorDescType(); + auto maskTy = getMaskType(); + auto valueTy = getValueType(); + + if (!tdescTy.getScattered()) + return emitOpError("Expects a scattered TensorDesc.\n"); + + if (!isReadHintOrNone(getL1HintAttr())) + return emitOpError("invlid l1_hint: ") << getL1HintAttr(); + + if (!isReadHintOrNone(getL2HintAttr())) + return emitOpError("invlid l2_hint: ") << getL2HintAttr(); + + if (!isReadHintOrNone(getL3HintAttr())) + return emitOpError("invlid l3_hint: ") << getL3HintAttr(); + + auto tdescElemTy = tdescTy.getElementType(); + auto valueElemTy = getElementType(); + if (tdescElemTy != valueElemTy) + return emitOpError( + "Value should have the same element type as TensorDesc."); + + auto maskShape = getShapeOf(maskTy); + auto valueShape = getShapeOf(valueTy); + 
auto tdescShape = getShapeOf(tdescTy); + + if (tdescShape[0] != maskShape[0]) + return emitOpError("dim-0 of the Mask and TensorDesc should be the same."); + + if (getTransposeAttr()) { + auto trans = getTranspose().value(); + if (tdescShape.size() < trans.size()) + emitWarning("Invalid transpose attr. It is ignored."); + else + transpose(trans, tdescShape); + } + + if (valueShape != tdescShape) + return emitOpError("Unexpected result shape") + << "(Expected shape: " << makeString(tdescShape) + << ", Given shape: " << makeString(valueShape) << ").\n"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// XeGPU_StoreScatterOp +//===----------------------------------------------------------------------===// +LogicalResult StoreScatterOp::verify() { + auto tdescTy = getTensorDescType(); + if (!tdescTy.getScattered()) + return emitOpError("Expects a scattered TensorDesc.\n"); + + if (!isWriteHintOrNone(getL1HintAttr())) + return emitOpError("invlid l1_hint: ") << getL1HintAttr(); + + if (!isWriteHintOrNone(getL2HintAttr())) + return emitOpError("invlid l2_hint: ") << getL2HintAttr(); + + if (!isWriteHintOrNone(getL3HintAttr())) + return emitOpError("invlid l3_hint: ") << getL3HintAttr(); + + auto maskTy = getMaskType(); + auto maskShape = getShapeOf(maskTy); + auto tdescShape = getShapeOf(tdescTy); + if (tdescShape[0] != maskShape[0]) + return emitOpError("dim-0 of the Mask and TensorDesc should be the same."); - if (dstTy.getShape() != valTy.getShape()) - return emitOpError() - << "The result shape should match the TensorDesc shape.\n"; return success(); } |
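
For reference, the shape handling in the LoadGatherOp verifier above follows one rule: dim 0 of the mask must match dim 0 of the scattered TensorDesc, and an optional transpose permutes the TensorDesc shape before it is compared against the loaded value's shape. Below is a standalone sketch of that check using plain vectors (illustrative only; the function names are invented and this is not MLIR code):

#include <cstdint>
#include <vector>

// Permute `shape` according to `trans`, mirroring the static transpose()
// helper added in XeGPUOps.cpp: shape[i] = old[trans[i]].
static void transposeShape(const std::vector<int64_t> &trans,
                           std::vector<int64_t> &shape) {
  std::vector<int64_t> old = shape;
  for (size_t i = 0; i < trans.size(); ++i)
    shape[i] = old[trans[i]];
}

// Sketch of the LoadGatherOp shape rule: mask dim 0 must equal TensorDesc
// dim 0, and the (optionally transposed) TensorDesc shape must equal the
// result value's shape. A transpose longer than the shape is ignored,
// matching the verifier's warning-and-ignore behavior.
static bool loadGatherShapesOk(std::vector<int64_t> tdescShape,
                               const std::vector<int64_t> &maskShape,
                               const std::vector<int64_t> &valueShape,
                               const std::vector<int64_t> *trans) {
  if (tdescShape.empty() || maskShape.empty() ||
      tdescShape[0] != maskShape[0])
    return false;
  if (trans && trans->size() <= tdescShape.size())
    transposeShape(*trans, tdescShape);
  return tdescShape == valueShape;
}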