Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 7
-rw-r--r-- | mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 32
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 14
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 74
-rw-r--r-- | mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 3
-rw-r--r-- | mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 105
-rw-r--r-- | mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 62
-rw-r--r-- | mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 52
-rw-r--r-- | mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 7
-rw-r--r-- | mlir/lib/IR/Diagnostics.cpp | 6
-rw-r--r-- | mlir/lib/TableGen/CodeGenHelpers.cpp | 4
-rw-r--r-- | mlir/lib/Target/LLVMIR/DebugImporter.cpp | 4
-rw-r--r-- | mlir/lib/Tools/PDLL/Parser/Parser.cpp | 2
-rw-r--r-- | mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp | 2
-rw-r--r-- | mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp | 2
15 files changed, 300 insertions, 76 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 5355909..41d8d53 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1723,17 +1723,18 @@ struct VectorBroadcastScalarToLowRankLowering
       return success();
     }
 
-    // For 1-d vector, we additionally do a `vectorshuffle`.
     auto v = LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(),
                                            vectorType, poison,
                                            adaptor.getSource(), zero);
+    // For 1-d vector, we additionally do a `shufflevector`.
     int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
     SmallVector<int32_t> zeroValues(width, 0);
 
     // Shuffle the value across the desired number of elements.
-    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
-                                                       zeroValues);
+    auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>(
+        broadcast.getLoc(), v, poison, zeroValues);
+    rewriter.replaceOp(broadcast, shuffle);
     return success();
   }
 };
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 71687b1..ddcbc44 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
@@ -390,7 +391,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
     // Load result or store value type can be vector or scalar.
     Type valOrResTy;
     if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
-      valOrResTy = op.getResult().getType();
+      valOrResTy =
+          this->getTypeConverter()->convertType(op.getResult().getType());
     else
       valOrResTy = adaptor.getValue().getType();
     VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
@@ -878,10 +880,30 @@ struct ConvertXeGPUToXeVMPass
       }
       return {};
     };
-    typeConverter.addSourceMaterialization(memrefMaterializationCast);
-    typeConverter.addSourceMaterialization(ui64MaterializationCast);
-    typeConverter.addSourceMaterialization(ui32MaterializationCast);
-    typeConverter.addSourceMaterialization(vectorMaterializationCast);
+
+    // If the result type of the original op is a single-element vector and the
+    // lowered type is a scalar, this materialization cast recreates the
+    // single-element vector by broadcasting the scalar value.
+    auto singleElementVectorMaterializationCast =
+        [](OpBuilder &builder, Type type, ValueRange inputs,
+           Location loc) -> Value {
+      if (inputs.size() != 1)
+        return {};
+      auto input = inputs.front();
+      if (input.getType().isIntOrIndexOrFloat()) {
+        // If the input is a scalar and the target type is a single-element
+        // vector, create the vector by broadcasting.
+        if (auto vecTy = dyn_cast<VectorType>(type)) {
+          if (vecTy.getNumElements() == 1) {
+            return vector::BroadcastOp::create(builder, loc, vecTy, input)
+                .getResult();
+          }
+        }
+      }
+      return {};
+    };
+    typeConverter.addSourceMaterialization(
+        singleElementVectorMaterializationCast);
     typeConverter.addTargetMaterialization(memrefMaterializationCast);
     typeConverter.addTargetMaterialization(ui32MaterializationCast);
     typeConverter.addTargetMaterialization(ui64MaterializationCast);
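To illustrate the new source materialization (a hypothetical sketch, not taken from the patch): when the converter lowers a `vector<1xf32>` result to a plain `f32` but an unconverted use still expects the vector type, the cast rebuilds the vector with a broadcast:

```mlir
// %scalar is the lowered f32 value; the materialization reconstructs
// the original single-element vector type for the remaining use.
%vec = vector.broadcast %scalar : f32 to vector<1xf32>
```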
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 7ca09d9..3eae67f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2826,6 +2826,20 @@ LogicalResult ShuffleVectorOp::verify() {
   return success();
 }
 
+// Fold the shufflevector op when v1 is a single-element 1-D vector and the
+// mask is a single zero. The fold result is v1 in this case.
+OpFoldResult ShuffleVectorOp::fold(FoldAdaptor adaptor) {
+  // Check if operand 0 is a single-element vector.
+  auto vecType = llvm::dyn_cast<VectorType>(getV1().getType());
+  if (!vecType || vecType.getRank() != 1 || vecType.getNumElements() != 1)
+    return {};
+  // Check if the mask is a single zero.
+  // Note: the mask is guaranteed to be non-empty.
+  if (getMask().size() != 1 || getMask()[0] != 0)
+    return {};
+  return getV1();
+}
+
 //===----------------------------------------------------------------------===//
 // Implementations for LLVM::LLVMFuncOp.
 //===----------------------------------------------------------------------===//
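A sketch of IR the new `ShuffleVectorOp::fold` affects (hypothetical example): shuffling a single-element vector with mask `[0]` is the identity, so the op folds away to its first operand.

```mlir
// Before: %r is %v1 rearranged by the identity mask.
%r = llvm.shufflevector %v1, %v2 [0] : vector<1xf32>
// After folding, uses of %r become uses of %v1.
```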
             << ". Only m8n8k4 with f16 supports other layouts.";
+    }
+  }
+
   return success();
 }
@@ -2047,6 +2078,23 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
   }
 }
 
+NVVM::IDArgPair
+ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
+                                            LLVM::ModuleTranslation &mt,
+                                            llvm::IRBuilderBase &builder) {
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(op.getA()));
+  args.push_back(mt.lookupValue(op.getB()));
+
+  bool hasRelu = op.getRelu();
+
+  llvm::Intrinsic::ID intId =
+      hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
+              : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
+
+  return {intId, std::move(args)};
+}
+
 #define GET_F32x2_TO_F6x2_ID(type, has_relu)                                   \
   has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite            \
            : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
@@ -2306,6 +2354,32 @@ static void nvvmInferResultRanges(Operation *op, Value result,
   }
 }
 
+/// Verify the range attribute satisfies LLVM ConstantRange constructor
+/// requirements for NVVM SpecialRangeableRegisterOp.
+static LogicalResult
+verifyConstantRangeAttr(Operation *op,
+                        std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
+  if (!rangeAttr)
+    return success();
+
+  const llvm::APInt &lower = rangeAttr->getLower();
+  const llvm::APInt &upper = rangeAttr->getUpper();
+
+  // Check the LLVM ConstantRange constructor condition.
+  if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
+    unsigned bitWidth = lower.getBitWidth();
+    llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
+    llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
+    return op->emitOpError(
+               "invalid range attribute: Lower == Upper, but they aren't min (")
+           << llvm::toString(minVal, 10, false) << ") or max ("
+           << llvm::toString(maxVal, 10, false)
+           << ") value! This is an invalid constant range.";
+  }
+
+  return success();
+}
+
 static llvm::Value *getAsPackedI32(llvm::Value *arg,
                                    llvm::IRBuilderBase &builder) {
   return builder.CreateBitCast(arg,
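For reference, the kind of IR the new `MmaOp::verify` layout check constrains looks roughly like this (syntax abridged from the NVVM dialect documentation; the SSA values and exact attribute spelling are illustrative):

```mlir
// Valid under the new rule: every shape other than m8n8k4-with-f16
// must use layoutA = row and layoutB = col.
%d = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
       {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
        shape = #nvvm.shape<m = 16, n = 8, k = 16>}
     : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
       -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
```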
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index c477c6c..dcc1ef9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -315,7 +315,8 @@ bool mlir::linalg::detail::isContractionBody(
   Value yielded = getSourceSkipUnary(terminator->getOperand(0));
   Operation *reductionOp = yielded.getDefiningOp();
-  if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
+  if (!reductionOp || reductionOp->getNumResults() != 1 ||
+      reductionOp->getNumOperands() != 2) {
     errs << "expected reduction op to be binary";
     return false;
   }
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 58256b0..45c54c7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7601,6 +7601,111 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), result);
 }
 
+namespace {
+
+/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
+/// constant large enough such that the result is the same at all indices.
+///
+/// For example, rewrite the 'greater than' comparison below,
+///
+/// ```mlir
+/// %cst = arith.constant dense<7> : vector<3xindex>
+/// %stp = vector.step : vector<3xindex>
+/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
+/// ```
+///
+/// as,
+///
+/// ```mlir
+/// %out = arith.constant dense<false> : vector<3xi1>
+/// ```
+///
+/// Above, `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the
+/// result is false at ALL indices, we fold. If the constant were 1, then
+/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do not fold,
+/// conservatively preferring the 'compact' vector.step representation.
+///
+/// Note: this folder only works for the case where the constant (`%cst` above)
+/// is the second operand of the comparison. The arith.cmpi canonicalizer will
+/// ensure that constants are always second (on the right).
+struct StepCompareFolder : public OpRewritePattern<StepOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(StepOp stepOp,
+                                PatternRewriter &rewriter) const override {
+    const int64_t stepSize = stepOp.getResult().getType().getNumElements();
+
+    for (OpOperand &use : stepOp.getResult().getUses()) {
+      auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
+      if (!cmpiOp)
+        continue;
+
+      // The arith.cmpi canonicalizer makes constants final operands.
+      const unsigned stepOperandNumber = use.getOperandNumber();
+      if (stepOperandNumber != 0)
+        continue;
+
+      // Check that operand 1 is a constant.
+      unsigned constOperandNumber = 1;
+      Value otherOperand = cmpiOp.getOperand(constOperandNumber);
+      std::optional<int64_t> maybeConstValue =
+          getConstantIntValue(otherOperand);
+      if (!maybeConstValue.has_value())
+        continue;
+
+      int64_t constValue = maybeConstValue.value();
+      arith::CmpIPredicate pred = cmpiOp.getPredicate();
+
+      auto maybeSplat = [&]() -> std::optional<bool> {
+        // Handle ult (unsigned less than) and uge (unsigned greater equal).
+        if ((pred == arith::CmpIPredicate::ult ||
+             pred == arith::CmpIPredicate::uge) &&
+            stepSize <= constValue)
+          return pred == arith::CmpIPredicate::ult;
+
+        // Handle ule and ugt.
+        if ((pred == arith::CmpIPredicate::ule ||
+             pred == arith::CmpIPredicate::ugt) &&
+            stepSize - 1 <= constValue) {
+          return pred == arith::CmpIPredicate::ule;
+        }
+
+        // Handle eq and ne.
+        if ((pred == arith::CmpIPredicate::eq ||
+             pred == arith::CmpIPredicate::ne) &&
+            stepSize <= constValue)
+          return pred == arith::CmpIPredicate::ne;
+
+        return std::nullopt;
+      }();
+
+      if (!maybeSplat.has_value())
+        continue;
+
+      rewriter.setInsertionPointAfter(cmpiOp);
+
+      auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
+      if (!type)
+        continue;
+
+      auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value());
+      Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
+                                                    type, boolAttr);
+
+      rewriter.replaceOp(cmpiOp, splat);
+      return success();
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
+void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<StepCompareFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Vector Masking Utilities
 //===----------------------------------------------------------------------===//
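A complementary case the new `StepCompareFolder` handles (hypothetical example, mirroring the pattern's own documentation): with `ult` and a constant at least the vector length, the comparison is true at every index, so it folds to a true splat.

```mlir
// Before: [0, 1, 2] < [3, 3, 3] holds at every index.
%cst = arith.constant dense<3> : vector<3xindex>
%stp = vector.step : vector<3xindex>
%out = arith.cmpi ult, %stp, %cst : vector<3xindex>

// After canonicalization:
%out = arith.constant dense<true> : vector<3xi1>
```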
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e95338f..12e6475 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -928,17 +928,20 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
     // Some values may be yielded multiple times and correspond to multiple
     // results. Deduplicating occurs by taking each result with its matching
     // yielded value, and:
-    // 1. recording the unique first position at which the value is yielded.
+    // 1. recording the unique first position at which the value with uses is
+    //    yielded.
     // 2. recording for the result, the first position at which the dedup'ed
     //    value is yielded.
     // 3. skipping from the new result types / new yielded values any result
     //    that has no use or whose yielded value has already been seen.
     for (OpResult result : warpOp.getResults()) {
+      if (result.use_empty())
+        continue;
       Value yieldOperand = yield.getOperand(result.getResultNumber());
       auto it = dedupYieldOperandPositionMap.insert(
           std::make_pair(yieldOperand, newResultTypes.size()));
       dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
-      if (result.use_empty() || !it.second)
+      if (!it.second)
         continue;
       newResultTypes.push_back(result.getType());
       newYieldValues.push_back(yieldOperand);
@@ -1843,16 +1846,16 @@
     newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
                               escapingValueDistTypesElse.end());
 
-    llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
     for (auto [idx, val] :
          llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
-      origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
       newWarpOpYieldValues.push_back(val);
       newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
     }
-    // Create the new `WarpOp` with the updated yield values and types.
-    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+    // Replace the old `WarpOp` with the new one that has additional yield
+    // values and types.
+    SmallVector<size_t> newIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
     // `ifOp` returns the result of the inner warp op.
     SmallVector<Type> newIfOpDistResTypes;
     for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
@@ -1870,8 +1873,8 @@
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(newWarpOp);
     auto newIfOp = scf::IfOp::create(
-        rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
-        static_cast<bool>(ifOp.thenBlock()),
+        rewriter, ifOp.getLoc(), newIfOpDistResTypes,
+        newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()),
         static_cast<bool>(ifOp.elseBlock()));
     auto encloseRegionInWarpOp = [&](Block *oldIfBranch, Block *newIfBranch,
@@ -1888,7 +1891,7 @@
       for (size_t i = 0; i < escapingValues.size();
            ++i, ++warpResRangeStart) {
         innerWarpInputVals.push_back(
-            newWarpOp.getResult(warpResRangeStart));
+            newWarpOp.getResult(newIndices[warpResRangeStart]));
         escapeValToBlockArgIndex[escapingValues[i]] =
             innerWarpInputTypes.size();
         innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
@@ -1936,17 +1939,8 @@
     // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
     // result.
     for (auto [origIdx, newIdx] : ifResultMapping)
-      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+      rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
                                     newIfOp.getResult(newIdx), newIfOp);
-    // Similarly, update any users of the `WarpOp` results that were not
-    // results of the `IfOp`.
-    for (auto [origIdx, newIdx] : origToNewYieldIdx)
-      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
-                                  newWarpOp.getResult(newIdx));
-    // Remove the original `WarpOp` and `IfOp`, they should not have any uses
-    // at this point.
-    rewriter.eraseOp(ifOp);
-    rewriter.eraseOp(warpOp);
     return success();
   }
@@ -2065,19 +2059,16 @@
                                escapingValueDistTypes.begin(),
                                escapingValueDistTypes.end());
     // Next, we insert all non-`ForOp` yielded values and their distributed
-    // types. We also create a mapping between the non-`ForOp` yielded value
-    // index and the corresponding new `WarpOp` yield value index (needed to
-    // update users later).
-    llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
+    // types.
     for (auto [i, v] :
          llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
-      nonForResultMapping[i] = newWarpOpYieldValues.size();
       newWarpOpYieldValues.push_back(v);
       newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
     }
     // Create the new `WarpOp` with the updated yield values and types.
-    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
-        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+    SmallVector<size_t> newIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
 
     // Next, we create a new `ForOp` with the init args yielded by the new
     // `WarpOp`.
@@ -2086,7 +2077,7 @@
     // escaping values in the new `WarpOp`.
     SmallVector<Value> newForOpOperands;
     for (size_t i = 0; i < escapingValuesStartIdx; ++i)
-      newForOpOperands.push_back(newWarpOp.getResult(i));
+      newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
 
     // Create a new `ForOp` outside the new `WarpOp` region.
     OpBuilder::InsertionGuard g(rewriter);
@@ -2110,7 +2101,7 @@
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
     for (size_t i = escapingValuesStartIdx;
          i < escapingValuesStartIdx + escapingValues.size(); ++i) {
-      innerWarpInput.push_back(newWarpOp.getResult(i));
+      innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
       argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
           innerWarpInputType.size();
       innerWarpInputType.push_back(
@@ -2146,20 +2137,11 @@
     if (!innerWarp.getResults().empty())
       scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
 
-    // Update the users of original `WarpOp` results that were coming from the
+    // Update the users of the new `WarpOp` results that were coming from the
     // original `ForOp` to the corresponding new `ForOp` result.
     for (auto [origIdx, newIdx] : forResultMapping)
-      rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+      rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
                                     newForOp.getResult(newIdx), newForOp);
-    // Similarly, update any users of the `WarpOp` results that were not
-    // results of the `ForOp`.
-    for (auto [origIdx, newIdx] : nonForResultMapping)
-      rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
-                                  newWarpOp.getResult(newIdx));
-    // Remove the original `WarpOp` and `ForOp`, they should not have any uses
-    // at this point.
-    rewriter.eraseOp(forOp);
-    rewriter.eraseOp(warpOp);
     // Update any users of escaping values that were forwarded to the
     // inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
     newForOp.walk([&](Operation *op) {
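To illustrate the `WarpOpDeadResult` tweak at the IR level (hypothetical example; `some.def`/`some.use` are placeholder ops): a result with no uses is now skipped before deduplication, so a dead duplicate yield no longer occupies a result slot.

```mlir
// %r1 yields the same value as %r0 but is dead; after the pattern the
// warp op keeps a single result for %v, used by %r0's consumer.
%r0, %r1 = gpu.warp_execute_on_lane_0(%laneid)[32]
    -> (vector<1xf32>, vector<1xf32>) {
  %v = "some.def"() : () -> vector<32xf32>
  gpu.yield %v, %v : vector<32xf32>, vector<32xf32>
}
"some.use"(%r0) : (vector<1xf32>) -> ()
```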
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 14639c5..fbae098 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -465,26 +465,33 @@ struct UnrollElementwisePattern : public RewritePattern {
     auto targetShape = getTargetShape(options, op);
     if (!targetShape)
       return failure();
+    int64_t targetShapeRank = targetShape->size();
     auto dstVecType = cast<VectorType>(op->getResult(0).getType());
     SmallVector<int64_t> originalSize =
         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
-    // Bail-out if rank(source) != rank(target). The main limitation here is the
-    // fact that `ExtractStridedSlice` requires the rank for the input and
-    // output to match. If needed, we can relax this later.
-    if (originalSize.size() != targetShape->size())
-      return rewriter.notifyMatchFailure(
-          op, "expected input vector rank to match target shape rank");
+    int64_t originalShapeRank = originalSize.size();
+    Location loc = op->getLoc();
+
+    // Handle rank mismatch by adding leading unit dimensions to targetShape.
+    SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
+    int64_t rankDiff = originalShapeRank - targetShapeRank;
+    std::fill(adjustedTargetShape.begin(),
+              adjustedTargetShape.begin() + rankDiff, 1);
+    std::copy(targetShape->begin(), targetShape->end(),
+              adjustedTargetShape.begin() + rankDiff);
+
+    int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
     // Prepare the result vector.
     Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                              rewriter.getZeroAttr(dstVecType));
-    SmallVector<int64_t> strides(targetShape->size(), 1);
-    VectorType newVecType =
+    SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
+    VectorType unrolledVecType =
         VectorType::get(*targetShape, dstVecType.getElementType());
 
     // Create the unrolled computation.
     for (SmallVector<int64_t> offsets :
-         StaticTileOffsetRange(originalSize, *targetShape)) {
+         StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
       SmallVector<Value> extractOperands;
       for (OpOperand &operand : op->getOpOperands()) {
         auto vecType = dyn_cast<VectorType>(operand.get().getType());
@@ -492,14 +499,31 @@
         if (!vecType) {
           extractOperands.push_back(operand.get());
           continue;
         }
-        extractOperands.push_back(
-            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
-                loc, operand.get(), offsets, *targetShape, strides));
+        Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+            loc, operand.get(), offsets, adjustedTargetShape, strides);
+
+        // Reshape to remove leading unit dims if needed.
+        if (adjustedTargetShapeRank > targetShapeRank) {
+          extracted = rewriter.createOrFold<vector::ShapeCastOp>(
+              loc, VectorType::get(*targetShape, vecType.getElementType()),
+              extracted);
+        }
+        extractOperands.push_back(extracted);
       }
+
       Operation *newOp = cloneOpWithOperandsAndTypes(
-          rewriter, loc, op, extractOperands, newVecType);
+          rewriter, loc, op, extractOperands, unrolledVecType);
+
+      Value computeResult = newOp->getResult(0);
+
+      // Use strides sized to targetShape for proper insertion.
+      SmallVector<int64_t> insertStrides =
+          (adjustedTargetShapeRank > targetShapeRank)
+              ? SmallVector<int64_t>(targetShapeRank, 1)
+              : strides;
+
       result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
-          loc, newOp->getResult(0), result, offsets, strides);
+          loc, computeResult, result, offsets, insertStrides);
     }
     rewriter.replaceOp(op, result);
     return success();
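A sketch of what `UnrollElementwisePattern` now emits for a rank mismatch (hypothetical values `%a` and `%acc`): unrolling a `vector<2x4xf32>` op with target shape `[4]` pads the target shape to `[1, 4]`, drops the leading unit dim with a shape cast before the computation, and inserts back with strides sized to the original target rank.

```mlir
%s = vector.extract_strided_slice %a
       {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]}
     : vector<2x4xf32> to vector<1x4xf32>
%t = vector.shape_cast %s : vector<1x4xf32> to vector<4xf32>
%u = arith.negf %t : vector<4xf32>
%r = vector.insert_strided_slice %u, %acc
       {offsets = [0, 0], strides = [1]}
     : vector<4xf32> into vector<2x4xf32>
```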
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 81b5788..e0a8ac4 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -20,8 +20,8 @@
 
 #define DEBUG_TYPE "xegpu"
 
-namespace mlir {
-namespace xegpu {
+using namespace mlir;
+using namespace mlir::xegpu;
 
 static bool isSharedMemory(const MemRefType &memrefTy) {
   Attribute attr = memrefTy.getMemorySpace();
@@ -1133,9 +1133,6 @@ LogicalResult MemDescSubviewOp::verify() {
   return success();
 }
 
-} // namespace xegpu
-} // namespace mlir
-
 namespace mlir {
 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
 } // namespace mlir
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index 776b5c6..4d81918 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -378,8 +378,10 @@ struct SourceMgrDiagnosticHandlerImpl {
   }
 
   // Otherwise, try to load the source file.
-  std::string ignored;
-  unsigned id = mgr.AddIncludeFile(std::string(filename), SMLoc(), ignored);
+  auto bufferOrErr = llvm::MemoryBuffer::getFile(filename);
+  if (!bufferOrErr)
+    return 0;
+  unsigned id = mgr.AddNewSourceBuffer(std::move(*bufferOrErr), SMLoc());
   filenameToBufId[filename] = id;
   return id;
 }
diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp
index cb90ef8..d52d5e7 100644
--- a/mlir/lib/TableGen/CodeGenHelpers.cpp
+++ b/mlir/lib/TableGen/CodeGenHelpers.cpp
@@ -49,9 +49,7 @@ StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
     raw_ostream &os, const RecordKeeper &records, StringRef tag)
     : os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}
 
-void StaticVerifierFunctionEmitter::emitOpConstraints(
-    ArrayRef<const Record *> opDefs) {
-  NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
+void StaticVerifierFunctionEmitter::emitOpConstraints() {
   emitTypeConstraints();
   emitAttrConstraints();
   emitPropConstraints();
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index 4bbcd8e..db39c70 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -34,11 +34,9 @@ Location DebugImporter::translateFuncLocation(llvm::Function *func) {
     return UnknownLoc::get(context);
 
   // Add a fused location to link the subprogram information.
-  StringAttr funcName = StringAttr::get(context, subprogram->getName());
   StringAttr fileName = StringAttr::get(context, subprogram->getFilename());
   return FusedLocWith<DISubprogramAttr>::get(
-      {NameLoc::get(funcName),
-       FileLineColLoc::get(fileName, subprogram->getLine(), /*column=*/0)},
+      {FileLineColLoc::get(fileName, subprogram->getLine(), /*column=*/0)},
      translate(subprogram), context);
 }
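The Diagnostics.cpp hunk above replaces `SourceMgr::AddIncludeFile` (which searches include directories) with a direct buffer load. A minimal sketch of the resulting idiom, with `bufferForFile` as a hypothetical stand-in for the handler's member function (0 works as a "not found" sentinel because `SourceMgr` buffer ids start at 1):

```cpp
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"

// Load a file directly and register it with the SourceMgr, instead of
// going through include-path resolution.
unsigned bufferForFile(llvm::SourceMgr &mgr, llvm::StringRef filename) {
  auto bufferOrErr = llvm::MemoryBuffer::getFile(filename);
  if (!bufferOrErr)
    return 0; // 0 is never a valid buffer id; it signals failure.
  return mgr.AddNewSourceBuffer(std::move(*bufferOrErr), llvm::SMLoc());
}
```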
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index c883baa..3236b4f 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -27,6 +27,7 @@
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/SaveAndRestore.h"
 #include "llvm/Support/ScopedPrinter.h"
+#include "llvm/Support/VirtualFileSystem.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Parser.h"
 #include <optional>
@@ -828,6 +829,7 @@ LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
   llvm::SourceMgr tdSrcMgr;
   tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
   tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
+  tdSrcMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
 
   // This class provides a context argument for the llvm::SourceMgr diagnostic
   // handler.
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
index 60b9567..1dbe7eca 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
@@ -31,6 +31,7 @@
 #include "llvm/Support/FileSystem.h"
 #include "llvm/Support/LSP/Logging.h"
 #include "llvm/Support/Path.h"
+#include "llvm/Support/VirtualFileSystem.h"
 #include <optional>
 
 using namespace mlir;
@@ -402,6 +403,7 @@ PDLDocument::PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents,
   llvm::append_range(includeDirs, extraDirs);
 
   sourceMgr.setIncludeDirs(includeDirs);
+  sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
 
   astContext.getDiagEngine().setHandlerFn([&](const ast::Diagnostic &diag) {
diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
index 3080b78..2d817be 100644
--- a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
+++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
@@ -17,6 +17,7 @@
 #include "llvm/Support/LSP/Logging.h"
 #include "llvm/Support/LSP/Protocol.h"
 #include "llvm/Support/Path.h"
+#include "llvm/Support/VirtualFileSystem.h"
 #include "llvm/TableGen/Parser.h"
 #include "llvm/TableGen/Record.h"
 #include <optional>
@@ -448,6 +449,7 @@ void TableGenTextFile::initialize(
     return;
   }
   sourceMgr.setIncludeDirs(includeDirs);
+  sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
 
   // This class provides a context argument for the SourceMgr diagnostic
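The three tool changes above share one idiom: a `SourceMgr` whose include processing should now go through an explicitly-set virtual file system. A minimal sketch of the setup (assuming the real file system is appropriate, as it is in these tools; `configureSourceMgr` is a hypothetical helper):

```cpp
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/VirtualFileSystem.h"
#include <string>
#include <vector>

// Point include resolution at an explicit VFS; without this, the
// SourceMgr has no filesystem to search for included files.
void configureSourceMgr(llvm::SourceMgr &sourceMgr,
                        const std::vector<std::string> &includeDirs) {
  sourceMgr.setIncludeDirs(includeDirs);
  sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
}
```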