aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp7
-rw-r--r--mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp32
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp14
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp74
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp3
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp105
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp62
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp52
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp7
-rw-r--r--mlir/lib/IR/Diagnostics.cpp6
-rw-r--r--mlir/lib/TableGen/CodeGenHelpers.cpp4
-rw-r--r--mlir/lib/Target/LLVMIR/DebugImporter.cpp4
-rw-r--r--mlir/lib/Tools/PDLL/Parser/Parser.cpp2
-rw-r--r--mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp2
-rw-r--r--mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp2
15 files changed, 300 insertions, 76 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 5355909..41d8d53 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1723,17 +1723,18 @@ struct VectorBroadcastScalarToLowRankLowering
return success();
}
- // For 1-d vector, we additionally do a `vectorshuffle`.
auto v =
LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType,
poison, adaptor.getSource(), zero);
+ // For 1-d vector, we additionally do a `shufflevector`.
int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
SmallVector<int32_t> zeroValues(width, 0);
// Shuffle the value across the desired number of elements.
- rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
- zeroValues);
+ auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>(
+ broadcast.getLoc(), v, poison, zeroValues);
+ rewriter.replaceOp(broadcast, shuffle);
return success();
}
};
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 71687b1..ddcbc44 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -20,6 +20,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
@@ -390,7 +391,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
// Load result or Store valye Type can be vector or scalar.
Type valOrResTy;
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
- valOrResTy = op.getResult().getType();
+ valOrResTy =
+ this->getTypeConverter()->convertType(op.getResult().getType());
else
valOrResTy = adaptor.getValue().getType();
VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
@@ -878,10 +880,30 @@ struct ConvertXeGPUToXeVMPass
}
return {};
};
- typeConverter.addSourceMaterialization(memrefMaterializationCast);
- typeConverter.addSourceMaterialization(ui64MaterializationCast);
- typeConverter.addSourceMaterialization(ui32MaterializationCast);
- typeConverter.addSourceMaterialization(vectorMaterializationCast);
+
+ // If result type of original op is single element vector and lowered type
+ // is scalar. This materialization cast creates a single element vector by
+ // broadcasting the scalar value.
+ auto singleElementVectorMaterializationCast =
+ [](OpBuilder &builder, Type type, ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (input.getType().isIntOrIndexOrFloat()) {
+ // If the input is a scalar, and the target type is a vector of single
+ // element, create a single element vector by broadcasting.
+ if (auto vecTy = dyn_cast<VectorType>(type)) {
+ if (vecTy.getNumElements() == 1) {
+ return vector::BroadcastOp::create(builder, loc, vecTy, input)
+ .getResult();
+ }
+ }
+ }
+ return {};
+ };
+ typeConverter.addSourceMaterialization(
+ singleElementVectorMaterializationCast);
typeConverter.addTargetMaterialization(memrefMaterializationCast);
typeConverter.addTargetMaterialization(ui32MaterializationCast);
typeConverter.addTargetMaterialization(ui64MaterializationCast);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 7ca09d9..3eae67f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2826,6 +2826,20 @@ LogicalResult ShuffleVectorOp::verify() {
return success();
}
+// Folding for shufflevector op when v1 is single element 1D vector
+// and the mask is a single zero. OpFoldResult will be v1 in this case.
+OpFoldResult ShuffleVectorOp::fold(FoldAdaptor adaptor) {
+ // Check if operand 0 is a single element vector.
+ auto vecType = llvm::dyn_cast<VectorType>(getV1().getType());
+ if (!vecType || vecType.getRank() != 1 || vecType.getNumElements() != 1)
+ return {};
+ // Check if the mask is a single zero.
+ // Note: The mask is guaranteed to be non-empty.
+ if (getMask().size() != 1 || getMask()[0] != 0)
+ return {};
+ return getV1();
+}
+
//===----------------------------------------------------------------------===//
// Implementations for LLVM::LLVMFuncOp.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 5edcc40b..2a8c330 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -309,6 +309,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
return success();
}
+LogicalResult ConvertF32x2ToF4x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from f32x2 to f4x2.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -787,6 +798,26 @@ LogicalResult MmaOp::verify() {
" attribute");
}
+ // Validate layout combinations. According to the operation description, most
+ // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16
+ // can use other layout combinations.
+ bool isM8N8K4_F16 =
+ (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 &&
+ getMultiplicandAPtxType() == MMATypes::f16);
+
+ if (!isM8N8K4_F16) {
+ // For all other shapes/types, layoutA must be row and layoutB must be col
+ if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) {
+ return emitOpError("requires layoutA = #nvvm.mma_layout<row> and "
+ "layoutB = #nvvm.mma_layout<col> for shape <")
+ << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2]
+ << "> with element types "
+ << stringifyEnum(*getMultiplicandAPtxType()) << " and "
+ << stringifyEnum(*getMultiplicandBPtxType())
+ << ". Only m8n8k4 with f16 supports other layouts.";
+ }
+ }
+
return success();
}
@@ -2047,6 +2078,23 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
+NVVM::IDArgPair
+ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getA()));
+ args.push_back(mt.lookupValue(op.getB()));
+
+ bool hasRelu = op.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
+
+ return {intId, std::move(args)};
+}
+
#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
@@ -2306,6 +2354,32 @@ static void nvvmInferResultRanges(Operation *op, Value result,
}
}
+/// Verify the range attribute satisfies LLVM ConstantRange constructor
+/// requirements for NVVM SpecialRangeableRegisterOp.
+static LogicalResult
+verifyConstantRangeAttr(Operation *op,
+ std::optional<LLVM::ConstantRangeAttr> rangeAttr) {
+ if (!rangeAttr)
+ return success();
+
+ const llvm::APInt &lower = rangeAttr->getLower();
+ const llvm::APInt &upper = rangeAttr->getUpper();
+
+ // Check LLVM ConstantRange constructor condition
+ if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) {
+ unsigned bitWidth = lower.getBitWidth();
+ llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth);
+ llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth);
+ return op->emitOpError(
+ "invalid range attribute: Lower == Upper, but they aren't min (")
+ << llvm::toString(minVal, 10, false) << ") or max ("
+ << llvm::toString(maxVal, 10, false)
+ << ") value! This is an invalid constant range.";
+ }
+
+ return success();
+}
+
static llvm::Value *getAsPackedI32(llvm::Value *arg,
llvm::IRBuilderBase &builder) {
return builder.CreateBitCast(arg,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index c477c6c..dcc1ef9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -315,7 +315,8 @@ bool mlir::linalg::detail::isContractionBody(
Value yielded = getSourceSkipUnary(terminator->getOperand(0));
Operation *reductionOp = yielded.getDefiningOp();
- if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
+ if (!reductionOp || reductionOp->getNumResults() != 1 ||
+ reductionOp->getNumOperands() != 2) {
errs << "expected reduction op to be binary";
return false;
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 58256b0..45c54c7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7601,6 +7601,111 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), result);
}
+namespace {
+
+/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
+/// constant large enough such that the result is the same at all indices.
+///
+/// For example, rewrite the 'greater than' comparison below,
+///
+/// ```mlir
+/// %cst = arith.constant dense<7> : vector<3xindex>
+/// %stp = vector.step : vector<3xindex>
+/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
+/// ```
+///
+/// as,
+///
+/// ```mlir
+/// %out = arith.constant dense<false> : vector<3xi1>.
+/// ```
+///
+/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result
+/// is false at ALL indices we fold. If the constant was 1, then
+/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do fold,
+/// conservatively preferring the 'compact' vector.step representation.
+///
+/// Note: this folder only works for the case where the constant (`%cst` above)
+/// is the second operand of the comparison. The arith.cmpi canonicalizer will
+/// ensure that constants are always second (on the right).
+struct StepCompareFolder : public OpRewritePattern<StepOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(StepOp stepOp,
+ PatternRewriter &rewriter) const override {
+ const int64_t stepSize = stepOp.getResult().getType().getNumElements();
+
+ for (OpOperand &use : stepOp.getResult().getUses()) {
+ auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
+ if (!cmpiOp)
+ continue;
+
+ // arith.cmpi canonicalizer makes constants final operands.
+ const unsigned stepOperandNumber = use.getOperandNumber();
+ if (stepOperandNumber != 0)
+ continue;
+
+ // Check that operand 1 is a constant.
+ unsigned constOperandNumber = 1;
+ Value otherOperand = cmpiOp.getOperand(constOperandNumber);
+ std::optional<int64_t> maybeConstValue =
+ getConstantIntValue(otherOperand);
+ if (!maybeConstValue.has_value())
+ continue;
+
+ int64_t constValue = maybeConstValue.value();
+ arith::CmpIPredicate pred = cmpiOp.getPredicate();
+
+ auto maybeSplat = [&]() -> std::optional<bool> {
+ // Handle ult (unsigned less than) and uge (unsigned greater equal).
+ if ((pred == arith::CmpIPredicate::ult ||
+ pred == arith::CmpIPredicate::uge) &&
+ stepSize <= constValue)
+ return pred == arith::CmpIPredicate::ult;
+
+ // Handle ule and ugt.
+ if ((pred == arith::CmpIPredicate::ule ||
+ pred == arith::CmpIPredicate::ugt) &&
+ stepSize - 1 <= constValue) {
+ return pred == arith::CmpIPredicate::ule;
+ }
+
+ // Handle eq and ne.
+ if ((pred == arith::CmpIPredicate::eq ||
+ pred == arith::CmpIPredicate::ne) &&
+ stepSize <= constValue)
+ return pred == arith::CmpIPredicate::ne;
+
+ return std::nullopt;
+ }();
+
+ if (!maybeSplat.has_value())
+ continue;
+
+ rewriter.setInsertionPointAfter(cmpiOp);
+
+ auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
+ if (!type)
+ continue;
+
+ auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value());
+ Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
+ type, boolAttr);
+
+ rewriter.replaceOp(cmpiOp, splat);
+ return success();
+ }
+
+ return failure();
+ }
+};
+} // namespace
+
+void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<StepCompareFolder>(context);
+}
+
//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e95338f..12e6475 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -928,17 +928,20 @@ struct WarpOpDeadResult : public WarpDistributionPattern {
// Some values may be yielded multiple times and correspond to multiple
// results. Deduplicating occurs by taking each result with its matching
// yielded value, and:
- // 1. recording the unique first position at which the value is yielded.
+ // 1. recording the unique first position at which the value with uses is
+ // yielded.
// 2. recording for the result, the first position at which the dedup'ed
// value is yielded.
// 3. skipping from the new result types / new yielded values any result
// that has no use or whose yielded value has already been seen.
for (OpResult result : warpOp.getResults()) {
+ if (result.use_empty())
+ continue;
Value yieldOperand = yield.getOperand(result.getResultNumber());
auto it = dedupYieldOperandPositionMap.insert(
std::make_pair(yieldOperand, newResultTypes.size()));
dedupResultPositionMap.insert(std::make_pair(result, it.first->second));
- if (result.use_empty() || !it.second)
+ if (!it.second)
continue;
newResultTypes.push_back(result.getType());
newYieldValues.push_back(yieldOperand);
@@ -1843,16 +1846,16 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
newWarpOpDistTypes.append(escapingValueDistTypesElse.begin(),
escapingValueDistTypesElse.end());
- llvm::SmallDenseMap<unsigned, unsigned> origToNewYieldIdx;
for (auto [idx, val] :
llvm::zip_equal(nonIfYieldIndices, nonIfYieldValues)) {
- origToNewYieldIdx[idx] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(val);
newWarpOpDistTypes.push_back(warpOp.getResult(idx).getType());
}
- // Create the new `WarpOp` with the updated yield values and types.
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+ // Replace the old `WarpOp` with the new one that has additional yield
+ // values and types.
+ SmallVector<size_t> newIndices;
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
// `ifOp` returns the result of the inner warp op.
SmallVector<Type> newIfOpDistResTypes;
for (auto [i, res] : llvm::enumerate(ifOp.getResults())) {
@@ -1870,8 +1873,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(newWarpOp);
auto newIfOp = scf::IfOp::create(
- rewriter, ifOp.getLoc(), newIfOpDistResTypes, newWarpOp.getResult(0),
- static_cast<bool>(ifOp.thenBlock()),
+ rewriter, ifOp.getLoc(), newIfOpDistResTypes,
+ newWarpOp.getResult(newIndices[0]), static_cast<bool>(ifOp.thenBlock()),
static_cast<bool>(ifOp.elseBlock()));
auto encloseRegionInWarpOp =
[&](Block *oldIfBranch, Block *newIfBranch,
@@ -1888,7 +1891,7 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
for (size_t i = 0; i < escapingValues.size();
++i, ++warpResRangeStart) {
innerWarpInputVals.push_back(
- newWarpOp.getResult(warpResRangeStart));
+ newWarpOp.getResult(newIndices[warpResRangeStart]));
escapeValToBlockArgIndex[escapingValues[i]] =
innerWarpInputTypes.size();
innerWarpInputTypes.push_back(escapingValueInputTypes[i]);
@@ -1936,17 +1939,8 @@ struct WarpOpScfIfOp : public WarpDistributionPattern {
// Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
// result.
for (auto [origIdx, newIdx] : ifResultMapping)
- rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+ rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
newIfOp.getResult(newIdx), newIfOp);
- // Similarly, update any users of the `WarpOp` results that were not
- // results of the `IfOp`.
- for (auto [origIdx, newIdx] : origToNewYieldIdx)
- rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
- newWarpOp.getResult(newIdx));
- // Remove the original `WarpOp` and `IfOp`, they should not have any uses
- // at this point.
- rewriter.eraseOp(ifOp);
- rewriter.eraseOp(warpOp);
return success();
}
@@ -2065,19 +2059,16 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
escapingValueDistTypes.begin(),
escapingValueDistTypes.end());
// Next, we insert all non-`ForOp` yielded values and their distributed
- // types. We also create a mapping between the non-`ForOp` yielded value
- // index and the corresponding new `WarpOp` yield value index (needed to
- // update users later).
- llvm::SmallDenseMap<unsigned, unsigned> nonForResultMapping;
+ // types.
for (auto [i, v] :
llvm::zip_equal(nonForResultIndices, nonForYieldedValues)) {
- nonForResultMapping[i] = newWarpOpYieldValues.size();
newWarpOpYieldValues.push_back(v);
newWarpOpDistTypes.push_back(warpOp.getResult(i).getType());
}
// Create the new `WarpOp` with the updated yield values and types.
- WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
+ SmallVector<size_t> newIndices;
+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes, newIndices);
// Next, we create a new `ForOp` with the init args yielded by the new
// `WarpOp`.
@@ -2086,7 +2077,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
// escaping values in the new `WarpOp`.
SmallVector<Value> newForOpOperands;
for (size_t i = 0; i < escapingValuesStartIdx; ++i)
- newForOpOperands.push_back(newWarpOp.getResult(i));
+ newForOpOperands.push_back(newWarpOp.getResult(newIndices[i]));
// Create a new `ForOp` outside the new `WarpOp` region.
OpBuilder::InsertionGuard g(rewriter);
@@ -2110,7 +2101,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
for (size_t i = escapingValuesStartIdx;
i < escapingValuesStartIdx + escapingValues.size(); ++i) {
- innerWarpInput.push_back(newWarpOp.getResult(i));
+ innerWarpInput.push_back(newWarpOp.getResult(newIndices[i]));
argIndexMapping[escapingValues[i - escapingValuesStartIdx]] =
innerWarpInputType.size();
innerWarpInputType.push_back(
@@ -2146,20 +2137,11 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
if (!innerWarp.getResults().empty())
scf::YieldOp::create(rewriter, forOp.getLoc(), innerWarp.getResults());
- // Update the users of original `WarpOp` results that were coming from the
+ // Update the users of the new `WarpOp` results that were coming from the
// original `ForOp` to the corresponding new `ForOp` result.
for (auto [origIdx, newIdx] : forResultMapping)
- rewriter.replaceAllUsesExcept(warpOp.getResult(origIdx),
+ rewriter.replaceAllUsesExcept(newWarpOp.getResult(origIdx),
newForOp.getResult(newIdx), newForOp);
- // Similarly, update any users of the `WarpOp` results that were not
- // results of the `ForOp`.
- for (auto [origIdx, newIdx] : nonForResultMapping)
- rewriter.replaceAllUsesWith(warpOp.getResult(origIdx),
- newWarpOp.getResult(newIdx));
- // Remove the original `WarpOp` and `ForOp`, they should not have any uses
- // at this point.
- rewriter.eraseOp(forOp);
- rewriter.eraseOp(warpOp);
// Update any users of escaping values that were forwarded to the
// inner `WarpOp`. These values are now arguments of the inner `WarpOp`.
newForOp.walk([&](Operation *op) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 14639c5..fbae098 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -465,26 +465,33 @@ struct UnrollElementwisePattern : public RewritePattern {
auto targetShape = getTargetShape(options, op);
if (!targetShape)
return failure();
+ int64_t targetShapeRank = targetShape->size();
auto dstVecType = cast<VectorType>(op->getResult(0).getType());
SmallVector<int64_t> originalSize =
*cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
- // Bail-out if rank(source) != rank(target). The main limitation here is the
- // fact that `ExtractStridedSlice` requires the rank for the input and
- // output to match. If needed, we can relax this later.
- if (originalSize.size() != targetShape->size())
- return rewriter.notifyMatchFailure(
- op, "expected input vector rank to match target shape rank");
+ int64_t originalShapeRank = originalSize.size();
+
Location loc = op->getLoc();
+
+ // Handle rank mismatch by adding leading unit dimensions to targetShape
+ SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
+ int64_t rankDiff = originalShapeRank - targetShapeRank;
+ std::fill(adjustedTargetShape.begin(),
+ adjustedTargetShape.begin() + rankDiff, 1);
+ std::copy(targetShape->begin(), targetShape->end(),
+ adjustedTargetShape.begin() + rankDiff);
+
+ int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
// Prepare the result vector.
Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
rewriter.getZeroAttr(dstVecType));
- SmallVector<int64_t> strides(targetShape->size(), 1);
- VectorType newVecType =
+ SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
+ VectorType unrolledVecType =
VectorType::get(*targetShape, dstVecType.getElementType());
// Create the unrolled computation.
for (SmallVector<int64_t> offsets :
- StaticTileOffsetRange(originalSize, *targetShape)) {
+ StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
SmallVector<Value> extractOperands;
for (OpOperand &operand : op->getOpOperands()) {
auto vecType = dyn_cast<VectorType>(operand.get().getType());
@@ -492,14 +499,31 @@ struct UnrollElementwisePattern : public RewritePattern {
extractOperands.push_back(operand.get());
continue;
}
- extractOperands.push_back(
- rewriter.createOrFold<vector::ExtractStridedSliceOp>(
- loc, operand.get(), offsets, *targetShape, strides));
+ Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, operand.get(), offsets, adjustedTargetShape, strides);
+
+ // Reshape to remove leading unit dims if needed
+ if (adjustedTargetShapeRank > targetShapeRank) {
+ extracted = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, VectorType::get(*targetShape, vecType.getElementType()),
+ extracted);
+ }
+ extractOperands.push_back(extracted);
}
+
Operation *newOp = cloneOpWithOperandsAndTypes(
- rewriter, loc, op, extractOperands, newVecType);
+ rewriter, loc, op, extractOperands, unrolledVecType);
+
+ Value computeResult = newOp->getResult(0);
+
+ // Use strides sized to targetShape for proper insertion
+ SmallVector<int64_t> insertStrides =
+ (adjustedTargetShapeRank > targetShapeRank)
+ ? SmallVector<int64_t>(targetShapeRank, 1)
+ : strides;
+
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
- loc, newOp->getResult(0), result, offsets, strides);
+ loc, computeResult, result, offsets, insertStrides);
}
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 81b5788..e0a8ac4 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -20,8 +20,8 @@
#define DEBUG_TYPE "xegpu"
-namespace mlir {
-namespace xegpu {
+using namespace mlir;
+using namespace mlir::xegpu;
static bool isSharedMemory(const MemRefType &memrefTy) {
Attribute attr = memrefTy.getMemorySpace();
@@ -1133,9 +1133,6 @@ LogicalResult MemDescSubviewOp::verify() {
return success();
}
-} // namespace xegpu
-} // namespace mlir
-
namespace mlir {
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
} // namespace mlir
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index 776b5c6..4d81918 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -378,8 +378,10 @@ struct SourceMgrDiagnosticHandlerImpl {
}
// Otherwise, try to load the source file.
- std::string ignored;
- unsigned id = mgr.AddIncludeFile(std::string(filename), SMLoc(), ignored);
+ auto bufferOrErr = llvm::MemoryBuffer::getFile(filename);
+ if (!bufferOrErr)
+ return 0;
+ unsigned id = mgr.AddNewSourceBuffer(std::move(*bufferOrErr), SMLoc());
filenameToBufId[filename] = id;
return id;
}
diff --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp
index cb90ef8..d52d5e7 100644
--- a/mlir/lib/TableGen/CodeGenHelpers.cpp
+++ b/mlir/lib/TableGen/CodeGenHelpers.cpp
@@ -49,9 +49,7 @@ StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
raw_ostream &os, const RecordKeeper &records, StringRef tag)
: os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}
-void StaticVerifierFunctionEmitter::emitOpConstraints(
- ArrayRef<const Record *> opDefs) {
- NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
+void StaticVerifierFunctionEmitter::emitOpConstraints() {
emitTypeConstraints();
emitAttrConstraints();
emitPropConstraints();
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index 4bbcd8e..db39c70 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -34,11 +34,9 @@ Location DebugImporter::translateFuncLocation(llvm::Function *func) {
return UnknownLoc::get(context);
// Add a fused location to link the subprogram information.
- StringAttr funcName = StringAttr::get(context, subprogram->getName());
StringAttr fileName = StringAttr::get(context, subprogram->getFilename());
return FusedLocWith<DISubprogramAttr>::get(
- {NameLoc::get(funcName),
- FileLineColLoc::get(fileName, subprogram->getLine(), /*column=*/0)},
+ {FileLineColLoc::get(fileName, subprogram->getLine(), /*column=*/0)},
translate(subprogram), context);
}
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index c883baa..3236b4f 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -27,6 +27,7 @@
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/ScopedPrinter.h"
+#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Parser.h"
#include <optional>
@@ -828,6 +829,7 @@ LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
llvm::SourceMgr tdSrcMgr;
tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc());
tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs());
+ tdSrcMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
// This class provides a context argument for the llvm::SourceMgr diagnostic
// handler.
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
index 60b9567..1dbe7eca 100644
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
@@ -31,6 +31,7 @@
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/Path.h"
+#include "llvm/Support/VirtualFileSystem.h"
#include <optional>
using namespace mlir;
@@ -402,6 +403,7 @@ PDLDocument::PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents,
llvm::append_range(includeDirs, extraDirs);
sourceMgr.setIncludeDirs(includeDirs);
+ sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
astContext.getDiagEngine().setHandlerFn([&](const ast::Diagnostic &diag) {
diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
index 3080b78..2d817be 100644
--- a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
+++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp
@@ -17,6 +17,7 @@
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/LSP/Protocol.h"
#include "llvm/Support/Path.h"
+#include "llvm/Support/VirtualFileSystem.h"
#include "llvm/TableGen/Parser.h"
#include "llvm/TableGen/Record.h"
#include <optional>
@@ -448,6 +449,7 @@ void TableGenTextFile::initialize(
return;
}
sourceMgr.setIncludeDirs(includeDirs);
+ sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
// This class provides a context argument for the SourceMgr diagnostic