aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp2
-rw-r--r--mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp2
-rw-r--r--mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp2
-rw-r--r--mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp9
-rw-r--r--mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp64
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp4
-rw-r--r--mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp2
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp5
-rw-r--r--mlir/lib/Dialect/Arith/IR/ArithOps.cpp2
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp7
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp6
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp6
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp28
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp2
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp2
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp2
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp57
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp37
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp11
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp6
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp2
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp2
-rw-r--r--mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp10
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp2
-rw-r--r--mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp4
-rw-r--r--mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp2
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp2
-rw-r--r--mlir/lib/Dialect/Tensor/IR/TensorOps.cpp4
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp2
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp2
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp11
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp293
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp30
-rw-r--r--mlir/lib/TableGen/Type.cpp2
-rw-r--r--mlir/lib/Target/LLVMIR/DebugTranslation.cpp6
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp55
-rw-r--r--mlir/lib/Transforms/ViewOpGraph.cpp5
40 files changed, 495 insertions, 201 deletions
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 41e333c..3a307a0 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -935,7 +935,7 @@ static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
.Case([](Float6E2M3FNType) { return 2u; })
.Case([](Float6E3M2FNType) { return 3u; })
.Case([](Float4E2M1FNType) { return 4u; })
- .Default([](Type) { return std::nullopt; });
+ .Default(std::nullopt);
}
/// If there is a scaled MFMA instruction for the input element types `aType`
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 247dba1..cfdcd9c 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -432,7 +432,7 @@ static Value getOriginalVectorValue(Value value) {
current = op.getSource();
return false;
})
- .Default([](Operation *) { return false; });
+ .Default(false);
if (!skipOp) {
break;
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 25f1e1b..425594b 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -259,7 +259,7 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
}
return std::nullopt;
})
- .Default([](auto) { return std::nullopt; });
+ .Default(std::nullopt);
}
static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index a9efada..ec182f1 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -846,13 +846,8 @@ struct NVGPUMBarrierInitLowering
Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Value count = truncToI32(b, adaptor.getCount());
- if (isMbarrierShared(mbarrierType)) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
- op, barrier, count, adaptor.getPredicate());
- } else {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
- adaptor.getPredicate());
- }
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count,
+ adaptor.getPredicate());
return success();
}
};
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index 7d0a236..76a822b 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -14,6 +14,7 @@
#include "mlir/Conversion/SCFToGPU/SCFToGPU.h"
+#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -27,6 +28,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/DebugLog.h"
#include <optional>
@@ -625,18 +627,49 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
bool seenSideeffects = false;
// Whether we have left a nesting scope (and hence are no longer innermost).
bool leftNestingScope = false;
+ LocalAliasAnalysis aliasAnalysis;
+ llvm::DenseSet<Value> writtenBuffer;
while (!worklist.empty()) {
Operation *op = worklist.pop_back_val();
// Now walk over the body and clone it.
// TODO: This is only correct if there either is no further scf.parallel
- // nested or this code is side-effect free. Otherwise we might need
- // predication. We are overly conservative for now and only allow
- // side-effects in the innermost scope.
+ // nested or this code has side-effect but the memory buffer is not
+ // alias to inner loop access buffer. Otherwise we might need
+ // predication.
if (auto nestedParallel = dyn_cast<ParallelOp>(op)) {
// Before entering a nested scope, make sure there have been no
- // sideeffects until now.
- if (seenSideeffects)
- return failure();
+ // sideeffects until now or the nested operations do not access the
+ // buffer written by outer scope.
+ if (seenSideeffects) {
+ WalkResult walkRes = nestedParallel.walk([&](Operation *nestedOp) {
+ if (isMemoryEffectFree(nestedOp))
+ return WalkResult::advance();
+
+ auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(nestedOp);
+ if (!memEffectInterface)
+ return WalkResult::advance();
+
+ SmallVector<MemoryEffects::EffectInstance> effects;
+ memEffectInterface.getEffects(effects);
+ for (const MemoryEffects::EffectInstance &effect : effects) {
+ if (isa<MemoryEffects::Read>(effect.getEffect()) ||
+ isa<MemoryEffects::Write>(effect.getEffect())) {
+ Value baseBuffer = effect.getValue();
+ if (!baseBuffer)
+ return WalkResult::interrupt();
+ for (Value val : writtenBuffer) {
+ if (aliasAnalysis.alias(baseBuffer, val) !=
+ AliasResult::NoAlias) {
+ return WalkResult::interrupt();
+ }
+ }
+ }
+ }
+ return WalkResult::advance();
+ });
+ if (walkRes.wasInterrupted())
+ return failure();
+ }
// A nested scf.parallel needs insertion of code to compute indices.
// Insert that now. This will also update the worklist with the loops
// body.
@@ -650,6 +683,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
rewriter.setInsertionPointAfter(parent);
leftNestingScope = true;
seenSideeffects = false;
+ writtenBuffer.clear();
} else if (auto reduceOp = dyn_cast<scf::ReduceOp>(op)) {
// Convert scf.reduction op
auto parentLoop = op->getParentOfType<ParallelOp>();
@@ -682,6 +716,24 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
Operation *clone = rewriter.clone(*op, cloningMap);
cloningMap.map(op->getResults(), clone->getResults());
// Check for side effects.
+ if (!isMemoryEffectFree(clone)) {
+ // Record the buffer accessed by the operations with write effects.
+ if (auto memEffectInterface =
+ dyn_cast<MemoryEffectOpInterface>(clone)) {
+ SmallVector<MemoryEffects::EffectInstance> effects;
+ memEffectInterface.getEffects(effects);
+ for (const MemoryEffects::EffectInstance &effect : effects) {
+ if (isa<MemoryEffects::Write>(effect.getEffect())) {
+ Value writtenBase = effect.getValue();
+ // Conservatively return failure if we cannot find the written
+ // address.
+ if (!writtenBase)
+ return failure();
+ writtenBuffer.insert(writtenBase);
+ }
+ }
+ }
+ }
// TODO: Handle region side effects properly.
seenSideeffects |=
!isMemoryEffectFree(clone) || clone->getNumRegions() != 0;
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 41d8d53..69a317ec 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -716,7 +716,7 @@ lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
llvmType, accumulator);
return LLVMRedIntrinOp::create(rewriter, loc, llvmType,
- /*startValue=*/accumulator, vectorOperand,
+ /*start_value=*/accumulator, vectorOperand,
fmf);
}
@@ -743,7 +743,7 @@ static Value lowerPredicatedReductionWithStartValue(
Value vectorLength =
createVectorLengthValue(rewriter, loc, vectorOperand.getType());
return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType,
- /*startValue=*/accumulator, vectorOperand,
+ /*satrt_value=*/accumulator, vectorOperand,
mask, vectorLength);
}
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index e2c7d80..91c1aa5 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -46,7 +46,7 @@ static bool isZeroConstant(Value val) {
[](auto floatAttr) { return floatAttr.getValue().isZero(); })
.Case<IntegerAttr>(
[](auto intAttr) { return intAttr.getValue().isZero(); })
- .Default([](auto) { return false; });
+ .Default(false);
}
static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index e08cc6f..d428fbf 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -1106,10 +1106,7 @@ static bool isUniformDefinition(Value value,
return false;
}
- if (!value.getType().isIntOrIndexOrFloat())
- return false;
-
- return true;
+ return value.getType().isIntOrIndexOrFloat();
}
/// Generates a broadcast op for the provided uniform value using the
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 898d76c..980442e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2751,7 +2751,7 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
.Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
.Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
.Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
- .Default([](Operation *op) { return std::nullopt; });
+ .Default(std::nullopt);
if (!maybeKind) {
return std::nullopt;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index d9d6934..8655ed3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -95,12 +95,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp,
SymbolTableCollection &symbolTables) {
- SymbolRefAttr sym =
- llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
- if (!sym)
- return nullptr;
- return dyn_cast_or_null<FuncOp>(
- symbolTables.lookupNearestSymbolFrom(callOp, sym));
+ return dyn_cast_or_null<FuncOp>(callOp.resolveCallableInTable(&symbolTables));
}
/// Return the FuncOp called by `callOp`.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index aa53f94..c233e24 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -285,12 +285,8 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
static func::FuncOp
getCalledFunction(func::CallOp callOp,
mlir::SymbolTableCollection &symbolTable) {
- SymbolRefAttr sym =
- llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
- if (!sym)
- return nullptr;
return dyn_cast_or_null<func::FuncOp>(
- symbolTable.lookupNearestSymbolFrom(callOp, sym));
+ callOp.resolveCallableInTable(&symbolTable));
}
/// Return "true" if the given function signature has tensor semantics.
diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
index d2c2138..025d1ac 100644
--- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
@@ -330,7 +330,7 @@ static Value getBase(Value v) {
v = op.getSrc();
return true;
})
- .Default([](Operation *) { return false; });
+ .Default(false);
if (!shouldContinue)
break;
}
@@ -354,7 +354,7 @@ static Value propagatesCapture(Operation *op) {
.Case([](memref::TransposeOp transpose) { return transpose.getIn(); })
.Case<memref::ExpandShapeOp, memref::CollapseShapeOp>(
[](auto op) { return op.getSrc(); })
- .Default([](Operation *) { return Value(); });
+ .Default(nullptr);
}
/// Returns `true` if the given operation is known to capture the given value,
@@ -371,7 +371,7 @@ static std::optional<bool> getKnownCapturingStatus(Operation *op, Value v) {
// These operations are known not to capture.
.Case([](memref::DeallocOp) { return false; })
// By default, we don't know anything.
- .Default([](Operation *) { return std::nullopt; });
+ .Default(std::nullopt);
}
/// Returns `true` if the value may be captured by any of its users, i.e., if
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index 81c3069..ec1571a 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -416,13 +416,39 @@ createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,
if (ci.clusterSize >= 32) {
if (chipset.majorVersion <= 9) {
// Broadcast last value from each row to next row.
- // Use row mask to avoid polluting rows 1 and 3.
+ // Use row mask to avoid polluting row 0 (and row 2 if wave-64).
dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res,
amdgpu::DPPPerm::row_bcast_15,
rewriter.getUnitAttr(), 0xa, allBanks,
/*bound_ctrl*/ false);
res = vector::makeArithReduction(
rewriter, loc, gpu::convertReductionKind(mode), res, dpp);
+
+ // For subgroupSize = 64, at this point lanes [16, 32) contain the full
+ // reduction over lanes [0, 32), but lanes [0, 16) do not. Similarly,
+ // lanes [48, 64) contain the full reduction over lanes [32, 64), but
+ // lanes [32, 48) do not.
+ //
+ // If subgroup size is 64 and cluster size is 64, we don't need lanes [0,
+ // 16) and [32, 48) to have the correct cluster-32 reduction values at
+ // this point, because only lane 63's value will ultimately be read in
+ // this full-cluster case.
+ //
+ // If subgroup size is 64 and cluster size is 32, we need to ensure that
+ // lanes [0, 16) and [32, 48) have the correct final cluster-32 reduction
+ // values (subgroup_reduce guarantees that all lanes within each cluster
+ // contain the final reduction value). We do this by broadcasting lane
+ // 31's value to lanes [0, 16) and lanes 63's value to lanes [32, 48).
+ //
+ // See https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations
+ // for an illustration of how this within-cluster broadcast works with a
+ // swizzle.
+ if (ci.subgroupSize == 64 && ci.clusterSize == 32) {
+ res =
+ amdgpu::SwizzleBitModeOp::create(rewriter, loc, res, /*and_mask=*/0,
+ /*or_mask=*/31,
+ /*xor_mask=*/0);
+ }
} else if (chipset.majorVersion <= 12) {
// Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2).
Value uint32Max = arith::ConstantOp::create(
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 3eae67f..2731069 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -698,7 +698,7 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
return structType.getBody()[memberIndex];
return nullptr;
})
- .Default(Type(nullptr));
+ .Default(nullptr);
}
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index cee943d..7d9058c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -1111,7 +1111,7 @@ memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,
.Case<IntegerType, FloatType>([](auto type) {
return type.getWidth() % 8 == 0 && type.getWidth() > 0;
})
- .Default([](Type) { return false; });
+ .Default(false);
if (!canConvertType)
return false;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index ac35eea..ce93d18 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -798,7 +798,7 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {
// clang-format on
.Case<PtrLikeTypeInterface>(
[](Type type) { return isCompatiblePtrType(type); })
- .Default([](Type) { return false; });
+ .Default(false);
if (!result)
compatibleTypes.erase(type);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index f0de4db..a5ffb9e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -896,6 +896,12 @@ std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
} else if (type == NVVM::MMATypes::f32) {
elementType = builder.getF32Type();
numberElements = 8;
+ } else if (type == NVVM::MMATypes::f64) {
+ elementType = builder.getF64Type();
+ if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
+ numberElements = 1;
+ else
+ numberElements = 2;
} else if (type == NVVM::MMATypes::tf32) {
elementType = builder.getI32Type();
numberElements = 4;
@@ -954,6 +960,14 @@ LogicalResult NVVM::WMMALoadOp::verify() {
return emitOpError() << "invalid attribute combination";
std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(
getEltype(), getFrag(), getM(), getN(), getK(), getContext());
+ // Special case for f64 fragments
+ Type f64Ty = Float64Type::get(getContext());
+ if (typeInfo.first == f64Ty && typeInfo.second == 1) {
+ if (getType() != f64Ty)
+ return emitOpError("expected destination type to be f64");
+ return success();
+ }
+ // Everything else is a struct
Type dstType = LLVM::LLVMStructType::getLiteral(
getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
if (getType() != dstType)
@@ -1608,9 +1622,52 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
}
//===----------------------------------------------------------------------===//
+// getPtx methods
+//===----------------------------------------------------------------------===//
+
+std::string NVVM::MBarrierInitOp::getPtx() {
+ unsigned addressSpace =
+ llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
+ return (addressSpace == NVVMMemorySpace::Shared)
+ ? std::string("mbarrier.init.shared.b64 [%0], %1;")
+ : std::string("mbarrier.init.b64 [%0], %1;");
+}
+
+//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
+mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierInitOp>(op);
+ unsigned addressSpace =
+ llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType())
+ .getAddressSpace();
+ llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared)
+ ? llvm::Intrinsic::nvvm_mbarrier_init_shared
+ : llvm::Intrinsic::nvvm_mbarrier_init;
+
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getAddr()));
+ args.push_back(mt.lookupValue(thisOp.getCount()));
+
+ return {id, std::move(args)};
+}
+
+mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::MBarrierInvalOp>(op);
+ unsigned addressSpace =
+ llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType())
+ .getAddressSpace();
+ llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared)
+ ? llvm::Intrinsic::nvvm_mbarrier_inval_shared
+ : llvm::Intrinsic::nvvm_mbarrier_inval;
+
+ return {id, {mt.lookupValue(thisOp.getAddr())}};
+}
+
#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cbc565b..3dc45ed 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region &region,
OpAsmSetValueNameFn setNameFn) {
for (Value v : getRegionInputArgs())
setNameFn(v, "in");
+ for (Value v : getRegionOutputArgs())
+ setNameFn(v, "init");
}
void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
@@ -1495,14 +1497,14 @@ void MapOp::build(
if (bodyBuild)
buildGenericRegion(builder, result.location, *result.regions.front(),
- inputs, /*outputs=*/{}, bodyBuild);
+ inputs, /*outputs=*/{init}, bodyBuild);
}
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
const OperationName &payloadOpName,
const NamedAttrList &payloadOpAttrs,
ArrayRef<Value> operands,
- bool initFirst = false) {
+ bool initFirst = false, bool mapInit = true) {
OpBuilder b(parser.getContext());
Region *body = result.addRegion();
Block &block = body->emplaceBlock();
@@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
// If initFirst flag is enabled, we consider init as the first position of
// payload operands.
if (initFirst) {
- payloadOpOperands.push_back(block.getArguments().back());
+ if (mapInit)
+ payloadOpOperands.push_back(block.getArguments().back());
for (const auto &arg : block.getArguments().drop_back())
payloadOpOperands.push_back(arg);
} else {
payloadOpOperands = {block.getArguments().begin(),
- block.getArguments().end()};
+ block.getArguments().end() - int(!mapInit)};
}
Operation *payloadOp = b.create(
@@ -1553,8 +1556,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
if (payloadOpName.has_value()) {
if (!result.operands.empty())
addBodyWithPayloadOp(parser, result, payloadOpName.value(),
- payloadOpAttrs,
- ArrayRef(result.operands).drop_back());
+ payloadOpAttrs, ArrayRef(result.operands), false,
+ false);
else
result.addRegion();
} else {
@@ -1570,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
-static bool canUseShortForm(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false,
+ bool mapInit = true) {
+ // `intFirst == true` implies that we want to map init arg
+ if (initFirst && !mapInit)
+ return false;
// Check if the body can be printed in short form. The following 4 conditions
// must be satisfied:
@@ -1582,7 +1589,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
// 2) The payload op must have the same number of operands as the number of
// block arguments.
if (payload.getNumOperands() == 0 ||
- payload.getNumOperands() != body->getNumArguments())
+ payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
return false;
// 3) If `initFirst` is true (e.g., for reduction ops), the init block
@@ -1600,7 +1607,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
}
} else {
for (const auto &[operand, bbArg] :
- llvm::zip(payload.getOperands(), body->getArguments())) {
+ llvm::zip(payload.getOperands(),
+ body->getArguments().drop_back(int(!mapInit)))) {
if (bbArg != operand)
return false;
}
@@ -1632,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
void MapOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- bool useShortForm = canUseShortForm(mapper);
+ bool useShortForm =
+ canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
if (useShortForm) {
printShortForm(p, &mapper->getOperations().front());
}
@@ -1658,11 +1667,13 @@ LogicalResult MapOp::verify() {
auto *bodyBlock = getBody();
auto blockArgs = bodyBlock->getArguments();
- // Checks if the number of `inputs` match the arity of the `mapper` region.
- if (getInputs().size() != blockArgs.size())
+ // Checks if the number of `inputs` + `init` match the arity of the `mapper`
+ // region.
+ if (getInputs().size() + 1 != blockArgs.size())
return emitOpError() << "expects number of operands to match the arity of "
"mapper, but got: "
- << getInputs().size() << " and " << blockArgs.size();
+ << getInputs().size() + 1 << " and "
+ << blockArgs.size();
// The parameters of mapper should all match the element type of inputs.
for (const auto &[bbArgType, inputArg] :
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8b89244..3a43382 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1958,7 +1958,7 @@ enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
/// Return true if either `op` or `permutation` are empty to allow a simpler
/// polymorphic implementation.
template <typename RelayoutOpTy>
-bool isValidPackingPermutation(
+static bool isValidPackingPermutation(
RelayoutOpTy op, ArrayRef<int64_t> permutation,
OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
static_assert(
@@ -4322,9 +4322,10 @@ DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
// InsertSliceToCopyOp
//===----------------------------------------------------------------------===//
template <typename OpTy>
-DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
- transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
+static DiagnosedSilenceableFailure
+doit(RewriterBase &rewriter, OpTy target,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
tensor::ParallelInsertSliceOp>() &&
"wrong op type");
@@ -4499,7 +4500,7 @@ DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
return true;
})
- .Default([&](Operation *op) { return false; });
+ .Default(false);
if (!supported) {
DiagnosedSilenceableFailure diag =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 3e31393..75bb175 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -31,10 +31,8 @@ using namespace mlir;
using namespace mlir::linalg;
static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) {
- // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot
- // trivially generalize a `linalg.map`, as it does not use the output as
- // region arguments in the block.
- if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
+ // Bailout if `linalgOp` is already a generic.
+ if (isa<GenericOp>(linalgOp))
return failure();
// Check if the operation has exactly one region.
if (linalgOp->getNumRegions() != 1) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index f05ffa8..6519c4f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -322,7 +322,7 @@ promoteSubViews(ImplicitLocOpBuilder &b,
tmp = arith::ConstantOp::create(b, IntegerAttr::get(et, 0));
return complex::CreateOp::create(b, t, tmp, tmp);
})
- .Default([](auto) { return Value(); });
+ .Default(nullptr);
if (!fillVal)
return failure();
linalg::FillOp::create(b, fillVal, promotionInfo->fullLocalView);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
index 27ccf3c..6becc1f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp
@@ -89,7 +89,7 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
ValueRange{input, collapsedKernel, iZp, kZp},
ValueRange{collapsedInit}, stride, dilation);
})
- .Default([](Operation *op) { return nullptr; });
+ .Default(nullptr);
if (!newConv)
return failure();
for (auto attr : preservedAttrs)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0f317ea..cb6199f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -656,7 +656,7 @@ mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
[&](auto op) { return CombiningKind::MUL; })
.Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
.Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
- .Default([&](auto op) { return std::nullopt; });
+ .Default(std::nullopt);
}
/// Check whether `outputOperand` is a reduction with a single combiner
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
index 1208fdd..e685089 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -104,7 +104,7 @@ static Value getTargetMemref(Operation *op) {
vector::MaskedStoreOp, vector::TransferReadOp,
vector::TransferWriteOp>(
[](auto op) { return op.getBase(); })
- .Default([](auto) { return Value{}; });
+ .Default(nullptr);
}
template <typename T>
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index 660c313..fbac28e 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -145,3 +145,13 @@ std::string mlir::acc::getRecipeName(mlir::acc::RecipeKind kind,
return recipeName;
}
+
+mlir::Value mlir::acc::getBaseEntity(mlir::Value val) {
+ if (auto partialEntityAccessOp =
+ dyn_cast<PartialEntityAccessOpInterface>(val.getDefiningOp())) {
+ if (!partialEntityAccessOp.isCompleteView())
+ return partialEntityAccessOp.getBaseEntity();
+ }
+
+ return val;
+}
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 4ebd90d..d380c46 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -55,7 +55,7 @@ static bool isShapePreserving(ForOp forOp, int64_t arg) {
? forOp.getInitArgs()[opResult.getResultNumber()]
: Value();
})
- .Default([&](auto op) { return Value(); });
+ .Default(nullptr);
}
return false;
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 0c8114d..938952e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -346,7 +346,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
llvm::TypeSwitch<Type, Type>(getType())
.Case<spirv::CooperativeMatrixType>(
[](auto coopType) { return coopType.getElementType(); })
- .Default([](Type) { return nullptr; });
+ .Default(nullptr);
// Case 1. -- matrices.
if (coopElementType) {
@@ -1708,7 +1708,7 @@ LogicalResult spirv::MatrixTimesScalarOp::verify() {
llvm::TypeSwitch<Type, Type>(getMatrix().getType())
.Case<spirv::CooperativeMatrixType, spirv::MatrixType>(
[](auto matrixType) { return matrixType.getElementType(); })
- .Default([](Type) { return nullptr; });
+ .Default(nullptr);
assert(elementType && "Unhandled type");
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index f895807..d1e275d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -731,7 +731,7 @@ std::optional<int64_t> SPIRVType::getSizeInBytes() {
return *elementSize * type.getNumElements();
return std::nullopt;
})
- .Default(std::optional<int64_t>());
+ .Default(std::nullopt);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 88e1ab6..cb9b7f6 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1467,7 +1467,7 @@ mlir::spirv::getNativeVectorShape(Operation *op) {
return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
.Case<vector::ReductionOp, vector::TransposeOp>(
[](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
- .Default([](Operation *) { return std::nullopt; });
+ .Default(std::nullopt);
}
LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ac72002..110bfdc 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -41,10 +41,6 @@
using namespace mlir;
using namespace mlir::tensor;
-using llvm::divideCeilSigned;
-using llvm::divideFloorSigned;
-using llvm::mod;
-
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index bce964e..c607ece 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -579,6 +579,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(),
/*init=*/tensorDestination);
Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+ linalgBody.addArgument(tensorType.getElementType(), loc);
// Create linalg::IndexOps.
rewriter.setInsertionPointToStart(&linalgBody);
@@ -1068,6 +1069,7 @@ struct SplatOpInterface
/*inputs=*/ValueRange(),
/*init=*/*tensorAlloc);
Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+ linalgBody.addArgument(tensorType.getElementType(), loc);
// Create linalg::IndexOps.
rewriter.setInsertionPointToStart(&linalgBody);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
index 69e649d..bc4f5a5 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
@@ -189,7 +189,7 @@ struct PadOpToConstant final : public OpRewritePattern<PadOp> {
return constantFoldPadOp<llvm::APInt>(
rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad);
})
- .Default(Value());
+ .Default(nullptr);
if (!newOp)
return rewriter.notifyMatchFailure(padTensorOp,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ad8255a..ae3423c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4336,7 +4336,7 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
// ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
if (auto splat =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
- DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
+ return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource());
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index f9aa28d5..83406c8 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -11,7 +11,6 @@
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
@@ -229,8 +228,10 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
}
if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
- return emitError()
- << "expected inst_data and lane_layout to have the same rank";
+ return emitError() << "expected inst_data and lane_layout to have the same "
+ "rank, got inst_data "
+ << inst_data.size() << ", lane_layout "
+ << lane_layout.size();
}
// sg_data is optional for Workgroup layout, but its presence requires
@@ -569,8 +570,8 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
// for gather and scatter ops, Low-precision types are packed in 32-bit units.
unsigned bitWidth = elementType.getIntOrFloatBitWidth();
int chunkAlignmentFactor =
- bitWidth < targetinfo::packedSizeInBitsForGatherScatter
- ? targetinfo::packedSizeInBitsForGatherScatter / bitWidth
+ bitWidth < xegpu::uArch::generalPackedFormatBitSize
+ ? xegpu::uArch::generalPackedFormatBitSize / bitWidth
: 1;
auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
if (scatterAttr) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 8fab255..90eae87 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -14,7 +14,6 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Attributes.h"
@@ -37,6 +36,8 @@
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+
namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT
@@ -104,6 +105,8 @@ public:
SmallVector<int> getLaneData() const;
+ SmallVector<int> getInstData() const;
+
bool isSliceLayout() const {
if (!isAssigned())
return false;
@@ -137,6 +140,13 @@ SmallVector<int> LayoutInfo::getLaneData() const {
[](int64_t val) { return static_cast<int>(val); });
}
+SmallVector<int> LayoutInfo::getInstData() const {
+ if (!isAssigned())
+ return {};
+ return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(),
+ [](int64_t val) { return static_cast<int>(val); });
+}
+
void LayoutInfo::print(raw_ostream &os) const {
if (isAssigned()) {
os << storage;
@@ -174,12 +184,14 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
SmallVector<int32_t> laneLayout;
SmallVector<int32_t> laneData;
+ SmallVector<int32_t> instData;
for (int64_t idx : permutation) {
laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
+ instData.push_back(static_cast<int32_t>(getInstData()[idx]));
}
- return LayoutInfo(
- xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData));
+ return LayoutInfo(xegpu::LayoutAttr::get(storage.getContext(), instData,
+ laneLayout, laneData));
}
//===----------------------------------------------------------------------===//
@@ -192,6 +204,28 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
using Lattice::Lattice;
};
+/// Helper Function to find a proper instruction multiple for the user-supplied
+/// sg-level data shape. `candidates` are uArch allowed shapes.
+/// `candidateMultiples` are uArch multiples of such shapes (e.g., block count).
+template <typename T>
+int getLargestDivisor(T dim, ArrayRef<T> candidates,
+ ArrayRef<T> candidateMultiples = {}) {
+ static_assert(std::is_integral<T>::value, "T must be an integer type");
+ int largest = -1;
+ SmallVector<T> multiples = {1};
+ if (!candidateMultiples.empty())
+ multiples =
+ SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
+ for (T candidate : candidates) {
+ for (T multiple : multiples) {
+ int value = static_cast<int>(candidate * multiple);
+ if (value != 0 && dim % value == 0 && value > largest)
+ largest = value;
+ }
+ }
+ return largest;
+}
+
/// Helper Functions to get default layouts. A `default layout` is a layout that
/// is assigned to a value when the layout is not fixed by some anchor operation
/// (like DPAS).
@@ -200,18 +234,32 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
- unsigned rank) {
+ unsigned rank,
+ const xegpu::uArch::uArch *uArch,
+ ArrayRef<int> instData) {
assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
if (rank == 1) {
return LayoutInfo(
- xegpu::LayoutAttr::get(ctx, {xegpu::targetinfo::subgroupSize}, {1}));
+ xegpu::LayoutAttr::get(ctx, instData, {uArch->getSubgroupSize()}, {1}));
}
return LayoutInfo(xegpu::LayoutAttr::get(
- ctx, {1, xegpu::targetinfo::subgroupSize}, {1, 1}));
+ ctx, instData, {1, uArch->getSubgroupSize()}, {1, 1}));
+}
+
+static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
+ unsigned rank, int subgroupSize) {
+ assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
+ if (rank == 1) {
+ return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1}));
+ }
+ return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
}
/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
+ const xegpu::uArch::uArch *uArch,
+ ArrayRef<int> instData,
+ unsigned packingSize,
bool isScattered = false) {
// Expecting a 1D or 2D vector.
assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
@@ -221,28 +269,25 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
"Expected int or float element type.");
// If the rank is 1, then return default layout for 1D vector.
if (vectorTy.getRank() == 1)
- return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1);
+ return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch, instData);
// Packing factor is determined by the element type bitwidth.
- int packingFactor = 1;
unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
+ int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
if (isScattered) {
- packingFactor =
- bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
- ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
- : 1;
- return LayoutInfo(xegpu::LayoutAttr::get(
- vectorTy.getContext(), {xegpu::targetinfo::subgroupSize, 1},
- {1, packingFactor}));
+ return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+ {uArch->getSubgroupSize(), 1},
+ {1, packingFactor}));
}
- if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
- packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth;
- return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
- {1, xegpu::targetinfo::subgroupSize},
+ return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+ {1, uArch->getSubgroupSize()},
{1, packingFactor}));
}
/// Helper to get the default layout for a vector type.
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
+ const xegpu::uArch::uArch *uArch,
+ ArrayRef<int> instData,
+ unsigned packingSize,
bool isScattered = false) {
// Expecting a 1D or 2D vector.
assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
@@ -252,27 +297,18 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
"Expected int or float element type.");
// If the rank is 1, then return default layout for 1D vector.
if (tdescTy.getRank() == 1)
- return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1);
+ return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch, instData);
// Packing factor is determined by the element type bitwidth.
unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
-
+ int subgroupSize = uArch->getSubgroupSize();
+ int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
if (isScattered) {
- int packingFactor =
- bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter
- ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth
- : 1;
return LayoutInfo(xegpu::LayoutAttr::get(
- tdescTy.getContext(), {xegpu::targetinfo::subgroupSize, 1},
- {1, packingFactor}));
+ tdescTy.getContext(), instData, {subgroupSize, 1}, {1, packingFactor}));
}
- int packingFactor =
- (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
- ? xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth
- : 1;
- return LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(),
- {1, xegpu::targetinfo::subgroupSize},
- {1, packingFactor}));
+ return LayoutInfo(xegpu::LayoutAttr::get(
+ tdescTy.getContext(), instData, {1, subgroupSize}, {1, packingFactor}));
}
/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
@@ -281,25 +317,25 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
/// `packedSizeInBitsForDefault`
/// * For B operand, the data must be packed in minimum
/// `packedSizeInBitsForDpasB`
-static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy,
- unsigned operandNum) {
+static LayoutInfo
+getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
+ const xegpu::uArch::uArch *uArch,
+ ArrayRef<int> instData, unsigned packingSize) {
Type elementTy = vectorTy.getElementType();
assert(elementTy.isIntOrFloat() &&
"Expected int or float type in DPAS operands");
- SmallVector<int32_t, 2> layout({1, xegpu::targetinfo::subgroupSize});
+ SmallVector<int32_t, 2> layout({1, uArch->getSubgroupSize()});
// For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
// must have the VNNI format.
- if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() <
- xegpu::targetinfo::packedSizeInBitsForDpasB) {
+ if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < packingSize) {
SmallVector<int32_t, 2> data(
- {static_cast<int32_t>(xegpu::targetinfo::packedSizeInBitsForDpasB /
- elementTy.getIntOrFloatBitWidth()),
+ {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
1});
return LayoutInfo(
- xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
+ xegpu::LayoutAttr::get(vectorTy.getContext(), instData, layout, data));
}
// Otherwise, return the default layout for the vector type.
- return getDefaultSIMTLayoutInfo(vectorTy);
+ return getDefaultSIMTLayoutInfo(vectorTy, uArch, instData, packingSize);
}
//===----------------------------------------------------------------------===//
@@ -456,7 +492,37 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
// Here we assign the default layout to the tensor descriptor operand of
// prefetch.
auto tdescTy = prefetch.getTensorDescType();
- auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy);
+
+ auto uArch = getUArch(getChipStr(prefetch).value_or(""));
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
+ uArch->getInstruction(
+ xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch));
+
+ auto blockWHC =
+ uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType());
+ if (!blockWHC)
+ prefetch.emitWarning("No known block params found for the element type.");
+ auto [bWidth, bHeight, bCount] = blockWHC.value();
+ SmallVector<int> instData;
+ int instWidth = getLargestDivisor(
+ static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth,
+ bCount);
+ if (instWidth == -1)
+ prefetch.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ if (tdescTy.getRank() == 1)
+ instData = {instWidth};
+ else {
+ int instHeight = getLargestDivisor(
+ static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
+ if (instHeight == -1)
+ prefetch.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ instData = {instHeight, instWidth};
+ }
+ auto prefetchLayout = getDefaultSIMTLayoutInfo(
+ tdescTy, uArch, instData, uArchInstruction->getPackedFormatBitSize());
// Propagate the layout to the source tensor descriptor.
propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
}
@@ -475,10 +541,11 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
reduction.emitWarning("Expecting output type to be 1D vector.");
return;
}
+ auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
// Given that the result is 1D, the layout of the operand should be 2D with
// default layout.
- LayoutInfo operandLayout =
- getDefaultSIMTLayoutInfo(reduction->getContext(), 2);
+ LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(
+ reduction->getContext(), 2, uArch->getSubgroupSize());
propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
// Accumulator should have the same layout as the result.
propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
@@ -557,15 +624,53 @@ void LayoutInfoPropagation::visitDpasOp(
ArrayRef<const LayoutInfoLattice *> results) {
VectorType aTy = dpas.getLhsType();
VectorType bTy = dpas.getRhsType();
- propagateIfChanged(
- operands[0], operands[0]->meet(getSIMTLayoutInfoForDPASOperand(aTy, 0)));
- propagateIfChanged(
- operands[1], operands[1]->meet(getSIMTLayoutInfoForDPASOperand(bTy, 1)));
+
+ auto uArch = getUArch(getChipStr(dpas).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
+ xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
+
+ const unsigned dataALen = aTy.getShape().front();
+ auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
+ const int maxALen =
+ getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
+ if (maxALen == -1)
+ dpas.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+
+ const unsigned dataBLen = bTy.getShape().back();
+ auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType());
+ const int maxBLen =
+ getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
+ if (maxBLen == -1)
+ dpas.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ SmallVector<int> instDataA = {maxALen, subgroupSize};
+ SmallVector<int> instDataB = {subgroupSize, maxBLen};
+
+ propagateIfChanged(operands[0],
+ operands[0]->meet(getSIMTLayoutInfoForDPASOperand(
+ aTy, 0, uArch, instDataA,
+ uArchInstruction->getPackedFormatBitSizeA())));
+ propagateIfChanged(operands[1],
+ operands[1]->meet(getSIMTLayoutInfoForDPASOperand(
+ bTy, 1, uArch, instDataB,
+ uArchInstruction->getPackedFormatBitSizeB())));
if (operands.size() > 2) {
VectorType cTy = dpas.getAccType();
- propagateIfChanged(
- operands[2],
- operands[2]->meet(getSIMTLayoutInfoForDPASOperand(cTy, 2)));
+ const unsigned dataCLen = bTy.getShape().back();
+ auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType());
+ const int maxCLen =
+ getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
+ if (maxCLen == -1)
+ dpas.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ SmallVector<int> instDataC = {maxALen, maxCLen};
+ propagateIfChanged(operands[2],
+ operands[2]->meet(getSIMTLayoutInfoForDPASOperand(
+ cTy, 2, uArch, instDataC,
+ uArchInstruction->getPackedFormatBitSizeB())));
}
}
@@ -573,7 +678,38 @@ void LayoutInfoPropagation::visitDpasOp(
void LayoutInfoPropagation::visitStoreNdOp(
xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
- LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(store.getValueType());
+
+ auto uArch = getUArch(getChipStr(store).value_or(""));
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
+ uArch->getInstruction(
+ xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
+ VectorType dataTy = store.getValueType();
+ auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
+ store.getValueType().getElementType());
+ if (!blockWHC)
+ store.emitWarning("No known block params found for the element type.");
+ auto [bWidth, bHeight, bCount] = blockWHC.value();
+ SmallVector<int> instData;
+ int instWidth = getLargestDivisor(
+ static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth,
+ bCount);
+ if (instWidth == -1)
+ store.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ if (dataTy.getRank() == 1)
+ instData = {instWidth};
+ else {
+ int instHeight = getLargestDivisor(
+ static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
+ if (instHeight == -1)
+ store.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ instData = {instHeight, instWidth};
+ }
+ LayoutInfo storeLayout =
+ getDefaultSIMTLayoutInfo(store.getValueType(), uArch, instData,
+ uArchInstruction->getPackedFormatBitSize());
// Both operands should have the same layout
for (LayoutInfoLattice *operand : operands)
propagateIfChanged(operand, operand->meet(storeLayout));
@@ -694,10 +830,23 @@ void LayoutInfoPropagation::visitLoadGatherOp(
load.emitWarning("Not propagating, non-vector payload supplied.");
return;
}
- LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*scattered*/ true);
+ auto uArch = getUArch(getChipStr(load).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
+ SmallVector<int> instData{subgroupSize};
+ if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
+ instData.push_back(chunkSize);
+ else if (auto srcTdescTy =
+ dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
+ if (srcTdescTy.getChunkSizeAsInt() > 1)
+ instData.push_back(chunkSize);
+ }
+ LayoutInfo layout = getDefaultSIMTLayoutInfo(
+ payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
+ /*scattered*/ true);
// Mask operand should have 1D default layout.
- LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1);
+ LayoutInfo maskLayout =
+ getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
// Propagate the new layout to the tensor descriptor operand.
if (isa<xegpu::TensorDescType>(load.getSourceType()))
@@ -717,8 +866,10 @@ void LayoutInfoPropagation::visitCreateDescOp(
// Need the layout of the descriptor to propagate to the operands.
if (!descLayout.isAssigned())
return;
+ auto uArch = getUArch(getChipStr(createDesc).value_or(""));
// For offset operand propagate 1D default layout.
- LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1);
+ LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
+ uArch->getSubgroupSize());
propagateIfChanged(operands[1], operands[1]->meet(layout));
}
@@ -735,18 +886,30 @@ void LayoutInfoPropagation::visitStoreScatterOp(
storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
return;
}
+ auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
+
auto payloadShape = payloadTy.getShape();
if (payloadShape.size() > 1)
assert(
- payloadShape[0] == xegpu::targetinfo::subgroupSize &&
+ payloadShape[0] == subgroupSize &&
"Expected the first dimension of 2D tensor descriptor to be equal to "
"subgroup size.");
- LayoutInfo payloadLayout =
- getDefaultSIMTLayoutInfo(payloadTy, /*scattered=*/true);
+ SmallVector<int> instData{subgroupSize};
+ if (auto chunkSize = storeScatter.getChunkSize().value_or(0); chunkSize > 1)
+ instData.push_back(chunkSize);
+ else if (auto dstTdescTy =
+ dyn_cast<xegpu::TensorDescType>(storeScatter.getDestType())) {
+ if (dstTdescTy.getChunkSizeAsInt() > 1)
+ instData.push_back(chunkSize);
+ }
+ LayoutInfo payloadLayout = getDefaultSIMTLayoutInfo(
+ payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
+ /*scattered=*/true);
LayoutInfo maskLayout =
- getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1);
+ getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
// Propagate the payload operand layout
propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
// Propagate the destination (if tdesc) operand layout
@@ -1023,9 +1186,13 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
LayoutInfo layout = analysis.getLayoutInfo(val);
if (!layout.isAssigned())
return {};
+ xegpu::DistributeLayoutAttr layoutAttr =
+ cast<xegpu::DistributeLayoutAttr>(layout.get());
+ if (this->layoutKind == "lane")
+ layoutAttr = layoutAttr.dropInstData();
if (layout.isSliceLayout())
- return cast<xegpu::SliceAttr>(layout.get());
- return cast<xegpu::LayoutAttr>(layout.get());
+ return cast<xegpu::SliceAttr>(layoutAttr);
+ return cast<xegpu::LayoutAttr>(layoutAttr);
};
mlir::OpBuilder builder(&getContext());
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d09dc19..5a3b27e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -11,10 +11,10 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -159,17 +159,18 @@ static bool requirePacked(const xegpu::LayoutAttr layout) {
/// Helper function to check if the layout requires a transpose effect.
static bool requireTranspose(const xegpu::LayoutAttr layout,
- const std::string &chipStr) {
+ const xegpu::uArch::uArch *uArch) {
// Return false for unsupported targets.
// TODO: Add more support or move to target info.
- if (chipStr != "pvc" && chipStr != "bmg")
+ if (uArch->getName().equals_insensitive("pvc") &&
+ uArch->getName().equals_insensitive("bmg"))
return false;
if (!layout)
return false;
auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
if (laneLayout.size() != 2)
return false;
- return laneLayout[0] == xegpu::targetinfo::subgroupSize && laneLayout[1] == 1;
+ return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
}
/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
@@ -199,6 +200,11 @@ struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
PatternRewriter &rewriter) const override {
+ auto uArch = getUArch(xegpu::getChipStr(gpuFuncOp).value_or(""));
+ if (!uArch)
+ return rewriter.notifyMatchFailure(
+ gpuFuncOp, "Subgroup distribution requires target attribute attached "
+ "to set the warp size");
// If the function only contains a single void return, skip.
if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
@@ -230,7 +236,7 @@ struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
auto warpOp = gpu::WarpExecuteOnLane0Op::create(
rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
- xegpu::targetinfo::subgroupSize, newGpuFunc.getArguments(),
+ uArch->getSubgroupSize(), newGpuFunc.getArguments(),
newGpuFunc.getArgumentTypes());
Block &warpBodyBlock = warpOp.getBodyRegion().front();
// Replace the ReturnOp of the original gpu function with a YieldOp.
@@ -495,14 +501,14 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
warpOp, "warp result is not a xegpu::LoadNd op");
auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
+ auto uArch = getUArch(xegpu::getChipStr(loadOp).value_or(""));
+ if (!uArch)
+ return rewriter.notifyMatchFailure(
+ loadOp, "xegpu::LoadNdOp require target attribute attached to "
+ "determine transpose "
+ "requirement");
// Chip information is required to decide if the layout requires transpose
// effect.
- auto chipStr = xegpu::getChipStr(loadOp);
- if (!chipStr)
- return rewriter.notifyMatchFailure(
- loadOp,
- "xegpu::LoadNdOp require chip information to determine transpose "
- "requirement");
// Expecting offsets to be present.
SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets();
if (offsets.empty())
@@ -556,7 +562,7 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
// Set the packed attribute if the layout requires it.
newLoadOp.setPacked(requirePacked(layout));
// Set the transpose attribute if the layout requires it.
- if (requireTranspose(layout, chipStr.value()))
+ if (requireTranspose(layout, uArch))
newLoadOp.setTranspose(
DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
Value distributedVal = newWarpOp.getResult(operandIdx);
diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp
index b31377e..0f1bf83 100644
--- a/mlir/lib/TableGen/Type.cpp
+++ b/mlir/lib/TableGen/Type.cpp
@@ -56,7 +56,7 @@ std::optional<StringRef> TypeConstraint::getBuilderCall() const {
StringRef value = init->getValue();
return value.empty() ? std::optional<StringRef>() : value;
})
- .Default([](auto *) { return std::nullopt; });
+ .Default(std::nullopt);
}
// Return the C++ type for this type (which may just be ::mlir::Type).
diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
index eeb8725..e3bcf27 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
@@ -390,7 +390,7 @@ llvm::DISubrange *DebugTranslation::translateImpl(DISubrangeAttr attr) {
.Case<>([&](LLVM::DIGlobalVariableAttr global) {
return translate(global);
})
- .Default([&](Attribute attr) { return nullptr; });
+ .Default(nullptr);
return metadata;
};
return llvm::DISubrange::get(llvmCtx, getMetadataOrNull(attr.getCount()),
@@ -420,10 +420,10 @@ DebugTranslation::translateImpl(DIGenericSubrangeAttr attr) {
.Case([&](LLVM::DILocalVariableAttr local) {
return translate(local);
})
- .Case<>([&](LLVM::DIGlobalVariableAttr global) {
+ .Case([&](LLVM::DIGlobalVariableAttr global) {
return translate(global);
})
- .Default([&](Attribute attr) { return nullptr; });
+ .Default(nullptr);
return metadata;
};
return llvm::DIGenericSubrange::get(llvmCtx,
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index f284540..8edec99 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4084,12 +4084,13 @@ static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
///
/// Fortran
/// map(tofrom: array(2:5, 3:2))
-/// or
-/// C++
-/// map(tofrom: array[1:4][2:3])
+///
/// We must calculate the initial pointer offset to pass across, this function
/// performs this using bounds.
///
+/// TODO/WARNING: This only supports Fortran's column major indexing currently
+/// as is noted in the note below and comments in the function, we must extend
+/// this function when we add a C++ frontend.
/// NOTE: which while specified in row-major order it currently needs to be
/// flipped for Fortran's column order array allocation and access (as
/// opposed to C++'s row-major, hence the backwards processing where order is
@@ -4125,46 +4126,28 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
// with a pointer that's being treated like an array and we have the
// underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base
// address (pointer pointing to the actual data) so we must caclulate the
- // offset using a single index which the following two loops attempts to
- // compute.
-
- // Calculates the size offset we need to make per row e.g. first row or
- // column only needs to be offset by one, but the next would have to be
- // the previous row/column offset multiplied by the extent of current row.
+ // offset using a single index which the following loop attempts to
+ // compute using the standard column-major algorithm e.g for a 3D array:
//
- // For example ([1][10][100]):
+ // ((((c_idx * b_len) + b_idx) * a_len) + a_idx)
//
- // - First row/column we move by 1 for each index increment
- // - Second row/column we move by 1 (first row/column) * 10 (extent/size of
- // current) for 10 for each index increment
- // - Third row/column we would move by 10 (second row/column) *
- // (extent/size of current) 100 for 1000 for each index increment
- std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
- for (size_t i = 1; i < bounds.size(); ++i) {
- if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
- bounds[i].getDefiningOp())) {
- dimensionIndexSizeOffset.push_back(builder.CreateMul(
- moduleTranslation.lookupValue(boundOp.getExtent()),
- dimensionIndexSizeOffset[i - 1]));
- }
- }
-
- // Now that we have calculated how much we move by per index, we must
- // multiply each lower bound offset in indexes by the size offset we
- // have calculated in the previous and accumulate the results to get
- // our final resulting offset.
+ // It is of note that it's doing column-major rather than row-major at the
+ // moment, but having a way for the frontend to indicate which major format
+ // to use or standardizing/canonicalizing the order of the bounds to compute
+ // the offset may be useful in the future when there's other frontends with
+ // different formats.
+ std::vector<llvm::Value *> dimensionIndexSizeOffset;
for (int i = bounds.size() - 1; i >= 0; --i) {
if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
bounds[i].getDefiningOp())) {
- if (idx.empty())
- idx.emplace_back(builder.CreateMul(
- moduleTranslation.lookupValue(boundOp.getLowerBound()),
- dimensionIndexSizeOffset[i]));
+ if (i == ((int)bounds.size() - 1))
+ idx.emplace_back(
+ moduleTranslation.lookupValue(boundOp.getLowerBound()));
else
idx.back() = builder.CreateAdd(
- idx.back(), builder.CreateMul(moduleTranslation.lookupValue(
- boundOp.getLowerBound()),
- dimensionIndexSizeOffset[i]));
+ builder.CreateMul(idx.back(), moduleTranslation.lookupValue(
+ boundOp.getExtent())),
+ moduleTranslation.lookupValue(boundOp.getLowerBound()));
}
}
}
diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 08cac1f..5790a77 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -158,7 +158,8 @@ private:
/// Emit a cluster (subgraph). The specified builder generates the body of the
/// cluster. Return the anchor node of the cluster.
- Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
+ Node emitClusterStmt(function_ref<void()> builder,
+ const std::string &label = "") {
int clusterId = ++counter;
os << "subgraph cluster_" << clusterId << " {\n";
os.indent();
@@ -269,7 +270,7 @@ private:
}
/// Emit a node statement.
- Node emitNodeStmt(std::string label, StringRef shape = kShapeNode,
+ Node emitNodeStmt(const std::string &label, StringRef shape = kShapeNode,
StringRef background = "") {
int nodeId = ++counter;
AttributeMap attrs;