Diffstat (limited to 'mlir/lib')
54 files changed, 781 insertions, 234 deletions
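Note: the most frequent mechanical change in this patch is swapping llvm::TypeSwitch's lambda default for the value-accepting Default overload. A minimal before/after sketch (illustrative only; selectCode is a hypothetical classifier, and this assumes an LLVM recent enough that TypeSwitch::Default takes a plain value, as the hunks below rely on):

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/TypeSwitch.h"
#include <optional>

static std::optional<uint32_t> selectCode(mlir::Type type) {
  return llvm::TypeSwitch<mlir::Type, std::optional<uint32_t>>(type)
      .Case([](mlir::FloatType) { return 0u; })
      .Case([](mlir::IntegerType) { return 1u; })
      // Before: .Default([](mlir::Type) { return std::nullopt; });
      .Default(std::nullopt); // After: pass the fallback value directly.
}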
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 1eca43d..3a307a0 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -935,7 +935,7 @@ static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) { .Case([](Float6E2M3FNType) { return 2u; }) .Case([](Float6E3M2FNType) { return 3u; }) .Case([](Float4E2M1FNType) { return 4u; }) - .Default([](Type) { return std::nullopt; }); + .Default(std::nullopt); } /// If there is a scaled MFMA instruction for the input element types `aType` @@ -1043,7 +1043,7 @@ wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, return ROCDL::wmma_i32_16x16x32_iu4::getOperationName(); } - llvm_unreachable("Unsupported k value"); + return std::nullopt; } /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma` @@ -1135,7 +1135,7 @@ static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType, return std::nullopt; } - llvm_unreachable("Unsupported k value"); + return std::nullopt; } /// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma` @@ -1164,7 +1164,7 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType, elemDestType, k); - llvm_unreachable("unhandled WMMA case"); + return std::nullopt; } namespace { diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index 247dba1..cfdcd9c 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -432,7 +432,7 @@ static Value getOriginalVectorValue(Value value) { current = op.getSource(); return false; }) - .Default([](Operation *) { return false; }); + .Default(false); if (!skipOp) { break; diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp index 25f1e1b..425594b 100644 --- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp +++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp @@ -259,7 +259,7 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> { } return std::nullopt; }) - .Default([](auto) { return std::nullopt; }); + .Default(std::nullopt); } static std::optional<std::string> getFuncName(gpu::ShuffleMode mode, diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index a9efada..ec182f1 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -846,13 +846,8 @@ struct NVGPUMBarrierInitLowering Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(), adaptor.getMbarId(), rewriter); Value count = truncToI32(b, adaptor.getCount()); - if (isMbarrierShared(mbarrierType)) { - rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>( - op, barrier, count, adaptor.getPredicate()); - } else { - rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count, - adaptor.getPredicate()); - } + rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count, + adaptor.getPredicate()); return success(); } }; diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp index 7d0a236..76a822b 100644 --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -14,6 +14,7 @@ #include 
"mlir/Conversion/SCFToGPU/SCFToGPU.h" +#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -27,6 +28,7 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/Support/DebugLog.h" #include <optional> @@ -625,18 +627,49 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp, bool seenSideeffects = false; // Whether we have left a nesting scope (and hence are no longer innermost). bool leftNestingScope = false; + LocalAliasAnalysis aliasAnalysis; + llvm::DenseSet<Value> writtenBuffer; while (!worklist.empty()) { Operation *op = worklist.pop_back_val(); // Now walk over the body and clone it. // TODO: This is only correct if there either is no further scf.parallel - // nested or this code is side-effect free. Otherwise we might need - // predication. We are overly conservative for now and only allow - // side-effects in the innermost scope. + // nested or this code has side-effect but the memory buffer is not + // alias to inner loop access buffer. Otherwise we might need + // predication. if (auto nestedParallel = dyn_cast<ParallelOp>(op)) { // Before entering a nested scope, make sure there have been no - // sideeffects until now. - if (seenSideeffects) - return failure(); + // sideeffects until now or the nested operations do not access the + // buffer written by outer scope. + if (seenSideeffects) { + WalkResult walkRes = nestedParallel.walk([&](Operation *nestedOp) { + if (isMemoryEffectFree(nestedOp)) + return WalkResult::advance(); + + auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(nestedOp); + if (!memEffectInterface) + return WalkResult::advance(); + + SmallVector<MemoryEffects::EffectInstance> effects; + memEffectInterface.getEffects(effects); + for (const MemoryEffects::EffectInstance &effect : effects) { + if (isa<MemoryEffects::Read>(effect.getEffect()) || + isa<MemoryEffects::Write>(effect.getEffect())) { + Value baseBuffer = effect.getValue(); + if (!baseBuffer) + return WalkResult::interrupt(); + for (Value val : writtenBuffer) { + if (aliasAnalysis.alias(baseBuffer, val) != + AliasResult::NoAlias) { + return WalkResult::interrupt(); + } + } + } + } + return WalkResult::advance(); + }); + if (walkRes.wasInterrupted()) + return failure(); + } // A nested scf.parallel needs insertion of code to compute indices. // Insert that now. This will also update the worklist with the loops // body. @@ -650,6 +683,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp, rewriter.setInsertionPointAfter(parent); leftNestingScope = true; seenSideeffects = false; + writtenBuffer.clear(); } else if (auto reduceOp = dyn_cast<scf::ReduceOp>(op)) { // Convert scf.reduction op auto parentLoop = op->getParentOfType<ParallelOp>(); @@ -682,6 +716,24 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp, Operation *clone = rewriter.clone(*op, cloningMap); cloningMap.map(op->getResults(), clone->getResults()); // Check for side effects. + if (!isMemoryEffectFree(clone)) { + // Record the buffer accessed by the operations with write effects. 
+ if (auto memEffectInterface = + dyn_cast<MemoryEffectOpInterface>(clone)) { + SmallVector<MemoryEffects::EffectInstance> effects; + memEffectInterface.getEffects(effects); + for (const MemoryEffects::EffectInstance &effect : effects) { + if (isa<MemoryEffects::Write>(effect.getEffect())) { + Value writtenBase = effect.getValue(); + // Conservatively return failure if we cannot find the written + // address. + if (!writtenBase) + return failure(); + writtenBuffer.insert(writtenBase); + } + } + } + } // TODO: Handle region side effects properly. seenSideeffects |= !isMemoryEffectFree(clone) || clone->getNumRegions() != 0; diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 41d8d53..69a317ec 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -716,7 +716,7 @@ lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc, accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc, llvmType, accumulator); return LLVMRedIntrinOp::create(rewriter, loc, llvmType, - /*startValue=*/accumulator, vectorOperand, + /*start_value=*/accumulator, vectorOperand, fmf); } @@ -743,7 +743,7 @@ static Value lowerPredicatedReductionWithStartValue( Value vectorLength = createVectorLengthValue(rewriter, loc, vectorOperand.getType()); return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType, - /*startValue=*/accumulator, vectorOperand, + /*start_value=*/accumulator, vectorOperand, mask, vectorLength); } diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp index e2c7d80..91c1aa5 100644 --- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp +++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp @@ -46,7 +46,7 @@ static bool isZeroConstant(Value val) { [](auto floatAttr) { return floatAttr.getValue().isZero(); }) .Case<IntegerAttr>( [](auto intAttr) { return intAttr.getValue().isZero(); }) - .Default([](auto) { return false; }); + .Default(false); } static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter, diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index fcbf66d..33e8f2e 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -194,8 +194,8 @@ class CreateNdDescToXeVMPattern // If source is a memref, we need to extract the aligned pointer as index. // Pointer type is passed as i32 or i64 by type converter.
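    // Only a ranked memref is required: dynamic shapes are fine here, since
    // this lowering only needs the aligned base pointer, not static extents.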
if (sourceMemrefTy) { - if (!sourceMemrefTy.hasStaticShape()) { - return rewriter.notifyMatchFailure(op, "Expected static memref shape."); + if (!sourceMemrefTy.hasRank()) { + return rewriter.notifyMatchFailure(op, "Expected ranked Memref."); } baseAddr = memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source); diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp index e08cc6f..d428fbf 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -1106,10 +1106,7 @@ static bool isUniformDefinition(Value value, return false; } - if (!value.getType().isIntOrIndexOrFloat()) - return false; - - return true; + return value.getType().isIntOrIndexOrFloat(); } /// Generates a broadcast op for the provided uniform value using the diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 898d76c..980442e 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2751,7 +2751,7 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) { .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; }) .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; }) .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; }) - .Default([](Operation *op) { return std::nullopt; }); + .Default(std::nullopt); if (!maybeKind) { return std::nullopt; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index d9d6934..8655ed3 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -95,12 +95,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, /// Return the FuncOp called by `callOp`. static FuncOp getCalledFunction(CallOpInterface callOp, SymbolTableCollection &symbolTables) { - SymbolRefAttr sym = - llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee()); - if (!sym) - return nullptr; - return dyn_cast_or_null<FuncOp>( - symbolTables.lookupNearestSymbolFrom(callOp, sym)); + return dyn_cast_or_null<FuncOp>(callOp.resolveCallableInTable(&symbolTables)); } /// Return the FuncOp called by `callOp`. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp index fb7f2bb..9ccbfd3 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -620,7 +620,8 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, LDBG() << "\n- bufferizes out-of-place due to parallel region:\n" << " unConflictingWrite = operand " << uConflictingWrite->getOperandNumber() << " of " - << *uConflictingWrite->getOwner(); + << OpWithFlags(uConflictingWrite->getOwner(), + OpPrintingFlags().skipRegions()); return true; } } @@ -631,7 +632,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, Operation *readingOp = uRead->getOwner(); LDBG() << "\n- check conflict:\n" << " uRead = operand " << uRead->getOperandNumber() << " of " - << *readingOp; + << OpWithFlags(readingOp, OpPrintingFlags().skipRegions()); // Find the definition of uRead by following the SSA use-def chain. 
// E.g.: @@ -655,7 +656,8 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, for (OpOperand *uConflictingWrite : usesWrite) { LDBG() << " unConflictingWrite = operand " << uConflictingWrite->getOperandNumber() << " of " - << *uConflictingWrite->getOwner(); + << OpWithFlags(uConflictingWrite->getOwner(), + OpPrintingFlags().skipRegions()); // Check if op dominance can be used to rule out read-after-write // conflicts. @@ -975,7 +977,7 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state, const DominanceInfo &domInfo) { LDBG() << "//===-------------------------------------------===//\n" << "Analyzing operand #" << operand.getOperandNumber() << " of " - << *operand.getOwner(); + << OpWithFlags(operand.getOwner(), OpPrintingFlags().skipRegions()); bool foundInterference = wouldCreateWriteToNonWritableBuffer(operand, state) || diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index aa53f94..c233e24 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -285,12 +285,8 @@ static void removeBufferizationAttributes(BlockArgument bbArg) { static func::FuncOp getCalledFunction(func::CallOp callOp, mlir::SymbolTableCollection &symbolTable) { - SymbolRefAttr sym = - llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee()); - if (!sym) - return nullptr; return dyn_cast_or_null<func::FuncOp>( - symbolTable.lookupNearestSymbolFrom(callOp, sym)); + callOp.resolveCallableInTable(&symbolTable)); } /// Return "true" if the given function signature has tensor semantics. diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt index 47740d3..e9da135 100644 --- a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRControlFlowTransforms BufferDeallocationOpInterfaceImpl.cpp BufferizableOpInterfaceImpl.cpp + StructuralTypeConversions.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp new file mode 100644 index 0000000..5e2a742 --- /dev/null +++ b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp @@ -0,0 +1,169 @@ +//===- StructuralTypeConversions.cpp - CF Structural Type Conversions ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns that convert the successor operands and block +// signatures of unstructured control flow (cf dialect) branch ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +namespace { + +/// Helper function for converting branch ops.
This function converts the +/// signature of the given block. If the new block signature is different from +/// `expectedTypes`, returns "failure". +static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter, + const TypeConverter *converter, + Operation *branchOp, Block *block, + TypeRange expectedTypes) { + assert(converter && "expected non-null type converter"); + assert(!block->isEntryBlock() && "entry blocks have no predecessors"); + + // There is nothing to do if the types already match. + if (block->getArgumentTypes() == expectedTypes) + return block; + + // Compute the new block argument types and convert the block. + std::optional<TypeConverter::SignatureConversion> conversion = + converter->convertBlockSignature(block); + if (!conversion) + return rewriter.notifyMatchFailure(branchOp, + "could not compute block signature"); + if (expectedTypes != conversion->getConvertedTypes()) + return rewriter.notifyMatchFailure( + branchOp, + "mismatch between adaptor operand types and computed block signature"); + return rewriter.applySignatureConversion(block, *conversion, converter); +} + +/// Flatten the given value ranges into a single vector of values. +static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) { + SmallVector<Value> result; + for (const ValueRange &vals : values) + llvm::append_range(result, vals); + return result; +} + +/// Convert the destination block signature (if necessary) and change the +/// operands of the branch op. +struct BranchOpConversion : public OpConversionPattern<cf::BranchOp> { + using OpConversionPattern<cf::BranchOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands()); + FailureOr<Block *> convertedBlock = + getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(), + TypeRange(ValueRange(flattenedAdaptor))); + if (failed(convertedBlock)) + return failure(); + rewriter.replaceOpWithNewOp<cf::BranchOp>(op, flattenedAdaptor, + *convertedBlock); + return success(); + } +}; + +/// Convert the destination block signatures (if necessary) and change the +/// operands of the branch op. 
+struct CondBranchOpConversion : public OpConversionPattern<cf::CondBranchOp> { + using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::CondBranchOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector<Value> flattenedAdaptorTrue = + flattenValues(adaptor.getTrueDestOperands()); + SmallVector<Value> flattenedAdaptorFalse = + flattenValues(adaptor.getFalseDestOperands()); + if (!llvm::hasSingleElement(adaptor.getCondition())) + return rewriter.notifyMatchFailure(op, + "expected single element condition"); + FailureOr<Block *> convertedTrueBlock = + getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(), + TypeRange(ValueRange(flattenedAdaptorTrue))); + if (failed(convertedTrueBlock)) + return failure(); + FailureOr<Block *> convertedFalseBlock = + getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(), + TypeRange(ValueRange(flattenedAdaptorFalse))); + if (failed(convertedFalseBlock)) + return failure(); + rewriter.replaceOpWithNewOp<cf::CondBranchOp>( + op, llvm::getSingleElement(adaptor.getCondition()), + flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(), + *convertedTrueBlock, *convertedFalseBlock); + return success(); + } +}; + +/// Convert the destination block signatures (if necessary) and change the +/// operands of the switch op. +struct SwitchOpConversion : public OpConversionPattern<cf::SwitchOp> { + using OpConversionPattern<cf::SwitchOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::SwitchOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Get or convert default block. + FailureOr<Block *> convertedDefaultBlock = getConvertedBlock( + rewriter, getTypeConverter(), op, op.getDefaultDestination(), + TypeRange(adaptor.getDefaultOperands())); + if (failed(convertedDefaultBlock)) + return failure(); + + // Get or convert all case blocks. 
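+    // Every case successor must take exactly the converted case-operand
+    // types; getConvertedBlock converts each block signature on demand and
+    // fails on a mismatch.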
+ SmallVector<Block *> caseDestinations; + SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands(); + for (auto it : llvm::enumerate(op.getCaseDestinations())) { + Block *b = it.value(); + FailureOr<Block *> convertedBlock = + getConvertedBlock(rewriter, getTypeConverter(), op, b, + TypeRange(caseOperands[it.index()])); + if (failed(convertedBlock)) + return failure(); + caseDestinations.push_back(*convertedBlock); + } + + rewriter.replaceOpWithNewOp<cf::SwitchOp>( + op, adaptor.getFlag(), *convertedDefaultBlock, + adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(), + caseDestinations, caseOperands); + return success(); + } +}; + +} // namespace + +void mlir::cf::populateCFStructuralTypeConversions( + const TypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>( + typeConverter, patterns.getContext(), benefit); +} + +void mlir::cf::populateCFStructuralTypeConversionTarget( + const TypeConverter &typeConverter, ConversionTarget &target) { + target.addDynamicallyLegalOp<cf::BranchOp, cf::CondBranchOp, cf::SwitchOp>( + [&](Operation *op) { return typeConverter.isLegal(op->getOperands()); }); +} + +void mlir::cf::populateCFStructuralTypeConversionsAndLegality( + const TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, PatternBenefit benefit) { + populateCFStructuralTypeConversions(typeConverter, patterns, benefit); + populateCFStructuralTypeConversionTarget(typeConverter, target); +} diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp index d2c2138..025d1ac 100644 --- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp @@ -330,7 +330,7 @@ static Value getBase(Value v) { v = op.getSrc(); return true; }) - .Default([](Operation *) { return false; }); + .Default(false); if (!shouldContinue) break; } @@ -354,7 +354,7 @@ static Value propagatesCapture(Operation *op) { .Case([](memref::TransposeOp transpose) { return transpose.getIn(); }) .Case<memref::ExpandShapeOp, memref::CollapseShapeOp>( [](auto op) { return op.getSrc(); }) - .Default([](Operation *) { return Value(); }); + .Default(nullptr); } /// Returns `true` if the given operation is known to capture the given value, @@ -371,7 +371,7 @@ static std::optional<bool> getKnownCapturingStatus(Operation *op, Value v) { // These operations are known not to capture. .Case([](memref::DeallocOp) { return false; }) // By default, we don't know anything. - .Default([](Operation *) { return std::nullopt; }); + .Default(std::nullopt); } /// Returns `true` if the value may be captured by any of its users, i.e., if diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp index 81c3069..ec1571a 100644 --- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp @@ -416,13 +416,39 @@ createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op, if (ci.clusterSize >= 32) { if (chipset.majorVersion <= 9) { // Broadcast last value from each row to next row. - // Use row mask to avoid polluting rows 1 and 3. + // Use row mask to avoid polluting row 0 (and row 2 if wave-64). 
dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::row_bcast_15, rewriter.getUnitAttr(), 0xa, allBanks, /*bound_ctrl*/ false); res = vector::makeArithReduction( rewriter, loc, gpu::convertReductionKind(mode), res, dpp); + + // For subgroupSize = 64, at this point lanes [16, 32) contain the full + // reduction over lanes [0, 32), but lanes [0, 16) do not. Similarly, + // lanes [48, 64) contain the full reduction over lanes [32, 64), but + // lanes [32, 48) do not. + // + // If subgroup size is 64 and cluster size is 64, we don't need lanes [0, + // 16) and [32, 48) to have the correct cluster-32 reduction values at + // this point, because only lane 63's value will ultimately be read in + // this full-cluster case. + // + // If subgroup size is 64 and cluster size is 32, we need to ensure that + // lanes [0, 16) and [32, 48) have the correct final cluster-32 reduction + // values (subgroup_reduce guarantees that all lanes within each cluster + // contain the final reduction value). We do this by broadcasting lane + // 31's value to lanes [0, 16) and lane 63's value to lanes [32, 48). + // + // See https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations + // for an illustration of how this within-cluster broadcast works with a + // swizzle. + if (ci.subgroupSize == 64 && ci.clusterSize == 32) { + res = + amdgpu::SwizzleBitModeOp::create(rewriter, loc, res, /*and_mask=*/0, + /*or_mask=*/31, + /*xor_mask=*/0); + } } else if (chipset.majorVersion <= 12) { // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2). Value uint32Max = arith::ConstantOp::create( diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 3eae67f..2731069 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -698,7 +698,7 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices, return structType.getBody()[memberIndex]; return nullptr; }) - .Default(Type(nullptr)); + .Default(nullptr); } } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index cee943d..7d9058c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -1111,7 +1111,7 @@ memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot, .Case<IntegerType, FloatType>([](auto type) { return type.getWidth() % 8 == 0 && type.getWidth() > 0; }) - .Default([](Type) { return false; }); + .Default(false); if (!canConvertType) return false; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index ac35eea..ce93d18 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -798,7 +798,7 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) { // clang-format on .Case<PtrLikeTypeInterface>( [](Type type) { return isCompatiblePtrType(type); }) - .Default([](Type) { return false; }); + .Default(false); if (!result) compatibleTypes.erase(type); diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index f0de4db..a5ffb9e 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -896,6 +896,12 @@ std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type, } else if (type == NVVM::MMATypes::f32) { elementType = builder.getF32Type(); numberElements = 8; + } else if (type ==
NVVM::MMATypes::f64) { + elementType = builder.getF64Type(); + if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b) + numberElements = 1; + else + numberElements = 2; } else if (type == NVVM::MMATypes::tf32) { elementType = builder.getI32Type(); numberElements = 4; @@ -954,6 +960,14 @@ LogicalResult NVVM::WMMALoadOp::verify() { return emitOpError() << "invalid attribute combination"; std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK( getEltype(), getFrag(), getM(), getN(), getK(), getContext()); + // Special case for f64 fragments + Type f64Ty = Float64Type::get(getContext()); + if (typeInfo.first == f64Ty && typeInfo.second == 1) { + if (getType() != f64Ty) + return emitOpError("expected destination type to be f64"); + return success(); + } + // Everything else is a struct Type dstType = LLVM::LLVMStructType::getLiteral( getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first)); if (getType() != dstType) @@ -1608,9 +1622,52 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op, } //===----------------------------------------------------------------------===// +// getPtx methods +//===----------------------------------------------------------------------===// + +std::string NVVM::MBarrierInitOp::getPtx() { + unsigned addressSpace = + llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace(); + return (addressSpace == NVVMMemorySpace::Shared) + ? std::string("mbarrier.init.shared.b64 [%0], %1;") + : std::string("mbarrier.init.b64 [%0], %1;"); +} + +//===----------------------------------------------------------------------===// // getIntrinsicID/getIntrinsicIDAndArgs methods //===----------------------------------------------------------------------===// +mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierInitOp>(op); + unsigned addressSpace = + llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType()) + .getAddressSpace(); + llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared) + ? llvm::Intrinsic::nvvm_mbarrier_init_shared + : llvm::Intrinsic::nvvm_mbarrier_init; + + // Fill the Intrinsic Args + llvm::SmallVector<llvm::Value *> args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getCount())); + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::MBarrierInvalOp>(op); + unsigned addressSpace = + llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType()) + .getAddressSpace(); + llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared) + ? 
llvm::Intrinsic::nvvm_mbarrier_inval_shared + : llvm::Intrinsic::nvvm_mbarrier_inval; + + return {id, {mt.lookupValue(thisOp.getAddr())}}; +} + #define CP_ASYNC_ID_IMPL(mod, size, suffix) \ llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index cbc565b..3dc45ed 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region &region, OpAsmSetValueNameFn setNameFn) { for (Value v : getRegionInputArgs()) setNameFn(v, "in"); + for (Value v : getRegionOutputArgs()) + setNameFn(v, "init"); } void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { @@ -1495,14 +1497,14 @@ void MapOp::build( if (bodyBuild) buildGenericRegion(builder, result.location, *result.regions.front(), - inputs, /*outputs=*/{}, bodyBuild); + inputs, /*outputs=*/{init}, bodyBuild); } static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef<Value> operands, - bool initFirst = false) { + bool initFirst = false, bool mapInit = true) { OpBuilder b(parser.getContext()); Region *body = result.addRegion(); Block &block = body->emplaceBlock(); @@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, // If initFirst flag is enabled, we consider init as the first position of // payload operands. if (initFirst) { - payloadOpOperands.push_back(block.getArguments().back()); + if (mapInit) + payloadOpOperands.push_back(block.getArguments().back()); for (const auto &arg : block.getArguments().drop_back()) payloadOpOperands.push_back(arg); } else { payloadOpOperands = {block.getArguments().begin(), - block.getArguments().end()}; + block.getArguments().end() - int(!mapInit)}; } Operation *payloadOp = b.create( @@ -1553,8 +1556,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { if (payloadOpName.has_value()) { if (!result.operands.empty()) addBodyWithPayloadOp(parser, result, payloadOpName.value(), - payloadOpAttrs, - ArrayRef(result.operands).drop_back()); + payloadOpAttrs, ArrayRef(result.operands), false, + false); else result.addRegion(); } else { @@ -1570,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } -static bool canUseShortForm(Block *body, bool initFirst = false) { +static bool canUseShortForm(Block *body, bool initFirst = false, + bool mapInit = true) { + // `initFirst == true` implies that we want to map the init arg + if (initFirst && !mapInit) + return false; // Check if the body can be printed in short form. The following 4 conditions // must be satisfied: @@ -1582,7 +1589,7 @@ // 2) The payload op must have the same number of operands as the number of // block arguments.
if (payload.getNumOperands() == 0 || - payload.getNumOperands() != body->getNumArguments()) + payload.getNumOperands() != body->getNumArguments() - int(!mapInit)) return false; // 3) If `initFirst` is true (e.g., for reduction ops), the init block @@ -1600,7 +1607,8 @@ } else { for (const auto &[operand, bbArg] : - llvm::zip(payload.getOperands(), body->getArguments())) { + llvm::zip(payload.getOperands(), + body->getArguments().drop_back(int(!mapInit)))) { if (bbArg != operand) return false; } @@ -1632,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { void MapOp::print(OpAsmPrinter &p) { Block *mapper = getBody(); - bool useShortForm = canUseShortForm(mapper); + bool useShortForm = + canUseShortForm(mapper, /*initFirst=*/false, /*mapInit=*/false); if (useShortForm) { printShortForm(p, &mapper->getOperations().front()); } @@ -1658,11 +1667,13 @@ LogicalResult MapOp::verify() { auto *bodyBlock = getBody(); auto blockArgs = bodyBlock->getArguments(); - // Checks if the number of `inputs` match the arity of the `mapper` region. - if (getInputs().size() != blockArgs.size()) + // Checks if the number of `inputs` + `init` matches the arity of the + // `mapper` region. + if (getInputs().size() + 1 != blockArgs.size()) return emitOpError() << "expects number of operands to match the arity of " "mapper, but got: " << getInputs().size() + 1 << " and " + << blockArgs.size(); // The parameters of mapper should all match the element type of inputs. for (const auto &[bbArgType, inputArg] : diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 794dda9..3a43382 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1958,7 +1958,7 @@ enum class OuterOrInnerPerm { Outer = 0, Inner = 1 }; /// Return true if either `op` or `permutation` are empty to allow a simpler /// polymorphic implementation.
template <typename RelayoutOpTy> -bool isValidPackingPermutation( +static bool isValidPackingPermutation( RelayoutOpTy op, ArrayRef<int64_t> permutation, OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) { static_assert( @@ -2464,6 +2464,8 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter, .setPaddingSizes(getMixedPaddingSizes()) .setPadToMultipleOf(getPadToMultipleOf()); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointAfter(targetOp); auto maybePadOps = rewriteAsPaddedOp( rewriter, cast<TilingInterface>(targetOp.getOperation()), options); if (failed(maybePadOps)) { @@ -4320,9 +4322,10 @@ DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne( // InsertSliceToCopyOp //===----------------------------------------------------------------------===// template <typename OpTy> -DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, - transform::ApplyToEachResultList &results, - transform::TransformState &state) { +static DiagnosedSilenceableFailure +doit(RewriterBase &rewriter, OpTy target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>() && "wrong op type"); @@ -4497,7 +4500,7 @@ DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne( maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op); return true; }) - .Default([&](Operation *op) { return false; }); + .Default(false); if (!supported) { DiagnosedSilenceableFailure diag = diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp index 3e31393..75bb175 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -31,10 +31,8 @@ using namespace mlir; using namespace mlir::linalg; static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) { - // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot - // trivially generalize a `linalg.map`, as it does not use the output as - // region arguments in the block. - if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp)) + // Bailout if `linalgOp` is already a generic. + if (isa<GenericOp>(linalgOp)) return failure(); // Check if the operation has exactly one region. if (linalgOp->getNumRegions() != 1) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp index 3e787a2..52ab92f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp @@ -288,10 +288,6 @@ FailureOr<PadTilingInterfaceResult> linalg::rewriteAsPaddedOp( return failure(); } - OpBuilder::InsertionGuard g(builder); - // Set IP after toPad because we also take the dims of toPad's output. - builder.setInsertionPointAfter(toPad); - // 1. Get the loopUpperBounds from the TilingInterface. 
SmallVector<Range> iterationDomain = toPad.getIterationDomain(builder); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index f05ffa8..6519c4f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -322,7 +322,7 @@ promoteSubViews(ImplicitLocOpBuilder &b, tmp = arith::ConstantOp::create(b, IntegerAttr::get(et, 0)); return complex::CreateOp::create(b, t, tmp, tmp); }) - .Default([](auto) { return Value(); }); + .Default(nullptr); if (!fillVal) return failure(); linalg::FillOp::create(b, fillVal, promotionInfo->fullLocalView); diff --git a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp index 27ccf3c..6becc1f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp @@ -89,7 +89,7 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel, ValueRange{input, collapsedKernel, iZp, kZp}, ValueRange{collapsedInit}, stride, dilation); }) - .Default([](Operation *op) { return nullptr; }); + .Default(nullptr); if (!newConv) return failure(); for (auto attr : preservedAttrs) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 9d62491..cb6199f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -656,7 +656,7 @@ mlir::linalg::getCombinerOpKind(Operation *combinerOp) { [&](auto op) { return CombiningKind::MUL; }) .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; }) .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; }) - .Default([&](auto op) { return std::nullopt; }); + .Default(std::nullopt); } /// Check whether `outputOperand` is a reduction with a single combiner @@ -3911,21 +3911,21 @@ struct Conv1DGenerator Value lhs = vector::TransferReadOp::create( rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero}, /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType)); - auto maybeMaskedLhs = maybeMaskXferOp( + auto *maybeMaskedLhs = maybeMaskXferOp( lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp()); // Read rhs slice of size {kw, c} @ [0, 0]. Value rhs = vector::TransferReadOp::create( rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero}, /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType)); - auto maybeMaskedRhs = maybeMaskXferOp( + auto *maybeMaskedRhs = maybeMaskXferOp( rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp()); // Read res slice of size {n, w, c} @ [0, 0, 0]. 
Value res = vector::TransferReadOp::create( rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero}, /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType)); - auto maybeMaskedRes = maybeMaskXferOp( + auto *maybeMaskedRes = maybeMaskXferOp( resType.getShape(), resType.getScalableDims(), res.getDefiningOp()); //===------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index 1208fdd..e685089 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -104,7 +104,7 @@ static Value getTargetMemref(Operation *op) { vector::MaskedStoreOp, vector::TransferReadOp, vector::TransferWriteOp>( [](auto op) { return op.getBase(); }) - .Default([](auto) { return Value{}; }); + .Default(nullptr); } template <typename T> diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp index f6b4534..40e769e 100644 --- a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp +++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp @@ -22,5 +22,24 @@ std::string OpenACCSupport::getVariableName(Value v) { return acc::getVariableName(v); } +std::string OpenACCSupport::getRecipeName(RecipeKind kind, Type type, + Value var) { + if (impl) + return impl->getRecipeName(kind, type, var); + // The default implementation assumes that only type matters + // and the actual instance of variable is not relevant. + auto recipeName = acc::getRecipeName(kind, type); + if (recipeName.empty()) + emitNYI(var ? var.getLoc() : UnknownLoc::get(type.getContext()), + "variable privatization (incomplete recipe name handling)"); + return recipeName; +} + +InFlightDiagnostic OpenACCSupport::emitNYI(Location loc, const Twine &message) { + if (impl) + return impl->emitNYI(loc, message); + return mlir::emitError(loc, "not yet implemented: " + message); +} + } // namespace acc } // namespace mlir diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index ca46629..35eba72 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -50,11 +50,11 @@ static void attachVarNameAttr(Operation *op, OpBuilder &builder, } } +template <typename T> struct MemRefPointerLikeModel - : public PointerLikeType::ExternalModel<MemRefPointerLikeModel, - MemRefType> { + : public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> { Type getElementType(Type pointer) const { - return cast<MemRefType>(pointer).getElementType(); + return cast<T>(pointer).getElementType(); } mlir::acc::VariableTypeCategory @@ -63,7 +63,7 @@ struct MemRefPointerLikeModel if (auto mappableTy = dyn_cast<MappableType>(varType)) { return mappableTy.getTypeCategory(varPtr); } - auto memrefTy = cast<MemRefType>(pointer); + auto memrefTy = cast<T>(pointer); if (!memrefTy.hasRank()) { // This memref is unranked - aka it could have any rank, including a // rank of 0 which could mean scalar. For now, return uncategorized. @@ -296,7 +296,10 @@ void OpenACCDialect::initialize() { // By attaching interfaces here, we make the OpenACC dialect dependent on // the other dialects. This is probably better than having dialects like LLVM // and memref be dependent on OpenACC. 
- MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext()); + MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>( + *getContext()); + UnrankedMemRefType::attachInterface< + MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext()); LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>( *getContext()); } diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp index 89adda82..fbac28e 100644 --- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp +++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region &region) { mlir::Operation *parentOp = region.getParentOp(); @@ -106,3 +107,51 @@ std::string mlir::acc::getVariableName(mlir::Value v) { return ""; } + +std::string mlir::acc::getRecipeName(mlir::acc::RecipeKind kind, + mlir::Type type) { + assert(kind == mlir::acc::RecipeKind::private_recipe || + kind == mlir::acc::RecipeKind::firstprivate_recipe || + kind == mlir::acc::RecipeKind::reduction_recipe); + if (!llvm::isa<mlir::acc::PointerLikeType, mlir::acc::MappableType>(type)) + return ""; + + std::string recipeName; + llvm::raw_string_ostream ss(recipeName); + ss << (kind == mlir::acc::RecipeKind::private_recipe ? "privatization_" + : kind == mlir::acc::RecipeKind::firstprivate_recipe + ? "firstprivatization_" + : "reduction_"); + + // Print the type using its dialect-defined textual format. + type.print(ss); + ss.flush(); + + // Replace invalid characters (anything that's not a letter, number, + // period, or underscore) since this needs to be a valid MLIR identifier. + for (char &c : recipeName) { + if (!std::isalnum(static_cast<unsigned char>(c)) && c != '.' && c != '_') { + if (c == '?') + c = 'U'; + else if (c == '*') + c = 'Z'; + else if (c == '(' || c == ')' || c == '[' || c == ']' || c == '{' || + c == '}' || c == '<' || c == '>') + c = '_'; + else + c = 'X'; + } + } + + return recipeName; +} + +mlir::Value mlir::acc::getBaseEntity(mlir::Value val) { + if (auto partialEntityAccessOp = + dyn_cast<PartialEntityAccessOpInterface>(val.getDefiningOp())) { + if (!partialEntityAccessOp.isCompleteView()) + return partialEntityAccessOp.getBaseEntity(); + } + + return val; +} diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp index 4ebd90d..d380c46 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp @@ -55,7 +55,7 @@ static bool isShapePreserving(ForOp forOp, int64_t arg) { ? forOp.getInitArgs()[opResult.getResultNumber()] : Value(); }) - .Default([&](auto op) { return Value(); }); + .Default(nullptr); } return false; } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 0c8114d..938952e 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -346,7 +346,7 @@ LogicalResult spirv::CompositeConstructOp::verify() { llvm::TypeSwitch<Type, Type>(getType()) .Case<spirv::CooperativeMatrixType>( [](auto coopType) { return coopType.getElementType(); }) - .Default([](Type) { return nullptr; }); + .Default(nullptr); // Case 1. -- matrices.
if (coopElementType) { @@ -1708,7 +1708,7 @@ LogicalResult spirv::MatrixTimesScalarOp::verify() { llvm::TypeSwitch<Type, Type>(getMatrix().getType()) .Case<spirv::CooperativeMatrixType, spirv::MatrixType>( [](auto matrixType) { return matrixType.getElementType(); }) - .Default([](Type) { return nullptr; }); + .Default(nullptr); assert(elementType && "Unhandled type"); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index f895807..d1e275d 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -731,7 +731,7 @@ std::optional<int64_t> SPIRVType::getSizeInBytes() { return *elementSize * type.getNumElements(); return std::nullopt; }) - .Default(std::optional<int64_t>()); + .Default(std::nullopt); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 88e1ab6..cb9b7f6 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -1467,7 +1467,7 @@ mlir::spirv::getNativeVectorShape(Operation *op) { return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op) .Case<vector::ReductionOp, vector::TransposeOp>( [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); }) - .Default([](Operation *) { return std::nullopt; }); + .Default(std::nullopt); } LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp index 46d0baa..61b5ad6 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp @@ -504,6 +504,14 @@ public: unsigned extraCursorVal = 0) : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {} + void setSparseEmitStrategy(SparseEmitStrategy strategy) override { + wrap->setSparseEmitStrategy(strategy); + } + + SparseEmitStrategy getSparseEmitStrategy() const override { + return wrap->getSparseEmitStrategy(); + } + SmallVector<Type> getCursorValTypes(OpBuilder &b) const override { return wrap->getCursorValTypes(b); } @@ -979,7 +987,7 @@ public: void SparseIterator::genInit(OpBuilder &b, Location l, const SparseIterator *p) { - if (emitStrategy == SparseEmitStrategy::kDebugInterface) { + if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) { std::string prefix = getDebugInterfacePrefix(); Operation *begin = b.create(l, b.getStringAttr(prefix + ".begin"), {}, getCursorValTypes(b)); @@ -994,7 +1002,7 @@ void SparseIterator::genInit(OpBuilder &b, Location l, } Value SparseIterator::genNotEnd(OpBuilder &b, Location l) { - if (emitStrategy == SparseEmitStrategy::kDebugInterface) { + if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) { std::string prefix = getDebugInterfacePrefix(); Operation *notEnd = b.create(l, b.getStringAttr(prefix + ".not_end"), getCursor(), b.getI1Type()); @@ -1005,7 +1013,7 @@ Value SparseIterator::genNotEnd(OpBuilder &b, Location l) { } void SparseIterator::locate(OpBuilder &b, Location l, Value crd) { - if (emitStrategy == SparseEmitStrategy::kDebugInterface) { + if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) { std::string prefix = getDebugInterfacePrefix(); SmallVector<Value> args = getCursor(); 
args.push_back(crd); @@ -1019,7 +1027,7 @@ void SparseIterator::locate(OpBuilder &b, Location l, Value crd) { } Value SparseIterator::deref(OpBuilder &b, Location l) { - if (emitStrategy == SparseEmitStrategy::kDebugInterface) { + if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) { std::string prefix = getDebugInterfacePrefix(); SmallVector<Value> args = getCursor(); Operation *deref = b.create(l, b.getStringAttr(prefix + ".deref"), @@ -1032,7 +1040,7 @@ Value SparseIterator::deref(OpBuilder &b, Location l) { ValueRange SparseIterator::forward(OpBuilder &b, Location l) { assert(!randomAccessible()); - if (emitStrategy == SparseEmitStrategy::kDebugInterface) { + if (getSparseEmitStrategy() == SparseEmitStrategy::kDebugInterface) { std::string prefix = getDebugInterfacePrefix(); Operation *next = b.create(l, b.getStringAttr(prefix + ".next"), getCursor(), getCursorValTypes(b)); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h index 642cb1a..3636f3f 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h @@ -177,10 +177,14 @@ protected: public: virtual ~SparseIterator() = default; - void setSparseEmitStrategy(SparseEmitStrategy strategy) { + virtual void setSparseEmitStrategy(SparseEmitStrategy strategy) { emitStrategy = strategy; } + virtual SparseEmitStrategy getSparseEmitStrategy() const { + return emitStrategy; + } + virtual std::string getDebugInterfacePrefix() const = 0; virtual SmallVector<Type> getCursorValTypes(OpBuilder &b) const = 0; diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index ac72002..110bfdc 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -41,10 +41,6 @@ using namespace mlir; using namespace mlir::tensor; -using llvm::divideCeilSigned; -using llvm::divideFloorSigned; -using llvm::mod; - /// Materialize a single constant operation from a given attribute value with /// the desired resultant type. Operation *TensorDialect::materializeConstant(OpBuilder &builder, diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index bce964e..c607ece 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -579,6 +579,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc, linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(), /*init=*/tensorDestination); Block &linalgBody = linalgOp.getMapper().emplaceBlock(); + linalgBody.addArgument(tensorType.getElementType(), loc); // Create linalg::IndexOps. rewriter.setInsertionPointToStart(&linalgBody); @@ -1068,6 +1069,7 @@ struct SplatOpInterface /*inputs=*/ValueRange(), /*init=*/*tensorAlloc); Block &linalgBody = linalgOp.getMapper().emplaceBlock(); + linalgBody.addArgument(tensorType.getElementType(), loc); // Create linalg::IndexOps. 
rewriter.setInsertionPointToStart(&linalgBody); diff --git a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp index 69e649d..bc4f5a5 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp @@ -189,7 +189,7 @@ struct PadOpToConstant final : public OpRewritePattern<PadOp> { return constantFoldPadOp<llvm::APInt>( rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad); }) - .Default(Value()); + .Default(nullptr); if (!newOp) return rewriter.notifyMatchFailure(padTensorOp, diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 0aff67f..bf3810f 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -606,6 +606,12 @@ Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc, return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr); } +unsigned mlir::tosa::getBitWidth(Type type) { + if (dyn_cast<tosa::mxint8Type>(type)) + return 8; + return type.getIntOrFloatBitWidth(); +} + //===----------------------------------------------------------------------===// // TOSA Operator Verifiers. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index ab363ee..ddd9c70 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -31,6 +31,7 @@ TosaProfileCompliance::TosaProfileCompliance() { const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6}; const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4}; const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8}; + const TypeInfo mxint8T = {mlir::tosa::mxint8Type::getTypeID(), 8}; // The profile-based compliance content below is auto-generated by a script // in https://git.mlplatform.org/tosa/specification.git @@ -625,6 +626,8 @@ TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) { return {"fp4e2m1"}; } else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) { return {"fp8e8m0"}; + } else if (typeInfo.typeID == tosa::mxint8Type::getTypeID()) { + return {"mxint8"}; } llvm_unreachable("unknown type"); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 4d0b61a..b54ed55 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -693,7 +693,7 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op, << " shape dimension cannot be dynamic"; } - int64_t element_bits = type.getElementTypeBitWidth(); + int64_t element_bits = tosa::getBitWidth(getElementTypeOrSelf(type)); int64_t element_bytes = std::max(INT64_C(1), element_bits / 8); int64_t size = element_bytes * type.getNumElements(); @@ -1217,9 +1217,10 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { return true; } } - } else if (mlir::isa<tosa::shapeType>(type)) { + } else if (isa<tosa::shapeType>(type)) + return true; + else if (isa<tosa::mxint8Type>(type)) return true; - } return false; } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index ad8255a..ae3423c 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ 
b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4336,7 +4336,7 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) { // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp. if (auto splat = llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) - DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>()); + return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>()); // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp. return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource()); diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index f9aa28d5..83406c8 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -11,7 +11,6 @@ #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" -#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h" #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" @@ -229,8 +228,10 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, } if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) { - return emitError() - << "expected inst_data and lane_layout to have the same rank"; + return emitError() << "expected inst_data and lane_layout to have the same " + "rank, got inst_data " + << inst_data.size() << ", lane_layout " + << lane_layout.size(); } // sg_data is optional for Workgroup layout, but its presence requires @@ -569,8 +570,8 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError, // for gather and scatter ops, Low-precision types are packed in 32-bit units. unsigned bitWidth = elementType.getIntOrFloatBitWidth(); int chunkAlignmentFactor = - bitWidth < targetinfo::packedSizeInBitsForGatherScatter - ? targetinfo::packedSizeInBitsForGatherScatter / bitWidth + bitWidth < xegpu::uArch::generalPackedFormatBitSize + ? 
xegpu::uArch::generalPackedFormatBitSize / bitWidth : 1; auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding); if (scatterAttr) { diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp index 8fab255..90eae87 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp @@ -14,7 +14,6 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" -#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h" #include "mlir/Dialect/XeGPU/Transforms/Passes.h" #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "mlir/IR/Attributes.h" @@ -37,6 +36,8 @@ #include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" + namespace mlir { namespace xegpu { #define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT @@ -104,6 +105,8 @@ public: SmallVector<int> getLaneData() const; + SmallVector<int> getInstData() const; + bool isSliceLayout() const { if (!isAssigned()) return false; @@ -137,6 +140,13 @@ SmallVector<int> LayoutInfo::getLaneData() const { [](int64_t val) { return static_cast<int>(val); }); } +SmallVector<int> LayoutInfo::getInstData() const { + if (!isAssigned()) + return {}; + return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(), + [](int64_t val) { return static_cast<int>(val); }); +} + void LayoutInfo::print(raw_ostream &os) const { if (isAssigned()) { os << storage; @@ -174,12 +184,14 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const { SmallVector<int32_t> laneLayout; SmallVector<int32_t> laneData; + SmallVector<int32_t> instData; for (int64_t idx : permutation) { laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx])); laneData.push_back(static_cast<int32_t>(getLaneData()[idx])); + instData.push_back(static_cast<int32_t>(getInstData()[idx])); } - return LayoutInfo( - xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData)); + return LayoutInfo(xegpu::LayoutAttr::get(storage.getContext(), instData, + laneLayout, laneData)); } //===----------------------------------------------------------------------===// @@ -192,6 +204,28 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> { using Lattice::Lattice; }; +/// Helper Function to find a proper instruction multiple for the user-supplied +/// sg-level data shape. `candidates` are uArch allowed shapes. +/// `candidateMultiples` are uArch multiples of such shapes (e.g., block count). +template <typename T> +int getLargestDivisor(T dim, ArrayRef<T> candidates, + ArrayRef<T> candidateMultiples = {}) { + static_assert(std::is_integral<T>::value, "T must be an integer type"); + int largest = -1; + SmallVector<T> multiples = {1}; + if (!candidateMultiples.empty()) + multiples = + SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end()); + for (T candidate : candidates) { + for (T multiple : multiples) { + int value = static_cast<int>(candidate * multiple); + if (value != 0 && dim % value == 0 && value > largest) + largest = value; + } + } + return largest; +} + /// Helper Functions to get default layouts. A `default layout` is a layout that /// is assigned to a value when the layout is not fixed by some anchor operation /// (like DPAS). @@ -200,18 +234,32 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> { /// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1]. 
/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1]. static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, - unsigned rank) { + unsigned rank, + const xegpu::uArch::uArch *uArch, + ArrayRef<int> instData) { assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector."); if (rank == 1) { return LayoutInfo( - xegpu::LayoutAttr::get(ctx, {xegpu::targetinfo::subgroupSize}, {1})); + xegpu::LayoutAttr::get(ctx, instData, {uArch->getSubgroupSize()}, {1})); } return LayoutInfo(xegpu::LayoutAttr::get( - ctx, {1, xegpu::targetinfo::subgroupSize}, {1, 1})); + ctx, instData, {1, uArch->getSubgroupSize()}, {1, 1})); +} + +static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, + unsigned rank, int subgroupSize) { + assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector."); + if (rank == 1) { + return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1})); + } + return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1})); } /// Helper to get the default layout for a vector type. static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy, + const xegpu::uArch::uArch *uArch, + ArrayRef<int> instData, + unsigned packingSize, bool isScattered = false) { // Expecting a 1D or 2D vector. assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) && @@ -221,28 +269,25 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy, "Expected int or float element type."); // If the rank is 1, then return default layout for 1D vector. if (vectorTy.getRank() == 1) - return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1); + return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch, instData); // Packing factor is determined by the element type bitwidth. - int packingFactor = 1; unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth(); + int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1; if (isScattered) { - packingFactor = - bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter - ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth - : 1; - return LayoutInfo(xegpu::LayoutAttr::get( - vectorTy.getContext(), {xegpu::targetinfo::subgroupSize, 1}, - {1, packingFactor})); + return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData, + {uArch->getSubgroupSize(), 1}, + {1, packingFactor})); } - if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault) - packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth; - return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), - {1, xegpu::targetinfo::subgroupSize}, + return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData, + {1, uArch->getSubgroupSize()}, {1, packingFactor})); } /// Helper to get the default layout for a vector type. static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy, + const xegpu::uArch::uArch *uArch, + ArrayRef<int> instData, + unsigned packingSize, bool isScattered = false) { // Expecting a 1D or 2D vector. assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) && @@ -252,27 +297,18 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy, "Expected int or float element type."); // If the rank is 1, then return default layout for 1D vector. if (tdescTy.getRank() == 1) - return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1); + return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch, instData); // Packing factor is determined by the element type bitwidth. 
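To make that rule concrete, here is a standalone check of the packing-factor computation shared by both overloads; the packing sizes (16 and 32 bits) are illustrative stand-ins for the uArch-provided constants, not values taken from a real target table:

#include <cassert>

// packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1
static int packingFactor(unsigned bitwidth, unsigned packingSize) {
  return bitwidth < packingSize ? static_cast<int>(packingSize / bitwidth) : 1;
}

int main() {
  assert(packingFactor(/*f32*/ 32, /*default*/ 16) == 1); // no packing needed
  assert(packingFactor(/*f16*/ 16, /*dpas-B*/ 32) == 2);  // 2 halves per dword
  assert(packingFactor(/*i8*/ 8, /*gather*/ 32) == 4);    // 4 bytes per dword
  return 0;
}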
unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth(); - + int subgroupSize = uArch->getSubgroupSize(); + int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1; if (isScattered) { - int packingFactor = - bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter - ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth - : 1; return LayoutInfo(xegpu::LayoutAttr::get( - tdescTy.getContext(), {xegpu::targetinfo::subgroupSize, 1}, - {1, packingFactor})); + tdescTy.getContext(), instData, {subgroupSize, 1}, {1, packingFactor})); } - int packingFactor = - (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault) - ? xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth - : 1; - return LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), - {1, xegpu::targetinfo::subgroupSize}, - {1, packingFactor})); + return LayoutInfo(xegpu::LayoutAttr::get( + tdescTy.getContext(), instData, {1, subgroupSize}, {1, packingFactor})); } /// Helper Function to get the expected layouts for DPAS operands. `lane_data` @@ -281,25 +317,25 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy, /// `packedSizeInBitsForDefault` /// * For B operand, the data must be packed in minimum /// `packedSizeInBitsForDpasB` -static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, - unsigned operandNum) { +static LayoutInfo +getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum, + const xegpu::uArch::uArch *uArch, + ArrayRef<int> instData, unsigned packingSize) { Type elementTy = vectorTy.getElementType(); assert(elementTy.isIntOrFloat() && "Expected int or float type in DPAS operands"); - SmallVector<int32_t, 2> layout({1, xegpu::targetinfo::subgroupSize}); + SmallVector<int32_t, 2> layout({1, uArch->getSubgroupSize()}); // For B operand, data must be packed in minimum `packedDpasBSizeInBits` and // must have the VNNI format. - if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < - xegpu::targetinfo::packedSizeInBitsForDpasB) { + if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < packingSize) { SmallVector<int32_t, 2> data( - {static_cast<int32_t>(xegpu::targetinfo::packedSizeInBitsForDpasB / - elementTy.getIntOrFloatBitWidth()), + {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()), 1}); return LayoutInfo( - xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data)); + xegpu::LayoutAttr::get(vectorTy.getContext(), instData, layout, data)); } // Otherwise, return the default layout for the vector type. - return getDefaultSIMTLayoutInfo(vectorTy); + return getDefaultSIMTLayoutInfo(vectorTy, uArch, instData, packingSize); } //===----------------------------------------------------------------------===// @@ -456,7 +492,37 @@ void LayoutInfoPropagation::visitPrefetchNdOp( // Here we assign the default layout to the tensor descriptor operand of // prefetch. 
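The instruction shapes assigned below all come from the `getLargestDivisor` helper defined earlier. A self-contained replica, with hypothetical block widths and block counts rather than a real uArch table, shows both the success path and the -1 failure path that triggers the warnings in these visitors:

#include <cassert>
#include <vector>

// Same search as the templated helper above, specialized to int.
static int largestDivisor(int dim, const std::vector<int> &candidates,
                          const std::vector<int> &candidateMultiples = {}) {
  int largest = -1;
  std::vector<int> multiples{1};
  if (!candidateMultiples.empty())
    multiples = candidateMultiples;
  for (int candidate : candidates)
    for (int multiple : multiples) {
      int value = candidate * multiple;
      if (value != 0 && dim % value == 0 && value > largest)
        largest = value;
    }
  return largest;
}

int main() {
  // Hypothetical block widths {8, 16, 32} and block counts {1, 2, 4}:
  // 16 = 8 * 2 is the largest combination that still divides a 48-wide tile.
  assert(largestDivisor(48, {8, 16, 32}, {1, 2, 4}) == 16);
  // Nothing divides a width of 7, so -1 flags "no suitable multiple".
  assert(largestDivisor(7, {8, 16, 32}) == -1);
  return 0;
}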
auto tdescTy = prefetch.getTensorDescType(); - auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy); + + auto uArch = getUArch(getChipStr(prefetch).value_or("")); + const auto *uArchInstruction = + dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>( + uArch->getInstruction( + xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch)); + + auto blockWHC = + uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType()); + if (!blockWHC) { + prefetch.emitWarning("No known block params found for the element type."); + return; + } + auto [bWidth, bHeight, bCount] = blockWHC.value(); + SmallVector<int> instData; + int instWidth = getLargestDivisor( + static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth, + bCount); + if (instWidth == -1) + prefetch.emitWarning( + "No suitable instruction multiple found for the given shape."); + if (tdescTy.getRank() == 1) + instData = {instWidth}; + else { + int instHeight = getLargestDivisor( + static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight); + if (instHeight == -1) + prefetch.emitWarning( + "No suitable instruction multiple found for the given shape."); + instData = {instHeight, instWidth}; + } + auto prefetchLayout = getDefaultSIMTLayoutInfo( + tdescTy, uArch, instData, uArchInstruction->getPackedFormatBitSize()); // Propagate the layout to the source tensor descriptor. propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout)); } @@ -475,10 +541,11 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp( reduction.emitWarning("Expecting output type to be 1D vector."); return; } + auto uArch = getUArch(xegpu::getChipStr(reduction).value_or("")); // Given that the result is 1D, the layout of the operand should be 2D with // default layout. - LayoutInfo operandLayout = - getDefaultSIMTLayoutInfo(reduction->getContext(), 2); + LayoutInfo operandLayout = getDefaultSIMTLayoutInfo( + reduction->getContext(), 2, uArch->getSubgroupSize());  propagateIfChanged(operands[0], operands[0]->meet(operandLayout)); // Accumulator should have the same layout as the result.
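Concretely, with a subgroup size of 16 (an assumed value; the real one comes from the uArch query above), the 2-D reduction source gets lane_layout = [1, 16] / lane_data = [1, 1], while the 1-D result and the accumulator get lane_layout = [16] / lane_data = [1]. A toy model of the two defaults:

#include <cassert>
#include <vector>

struct ToyLayout {
  std::vector<int> laneLayout;
  std::vector<int> laneData;
};

// Mirrors getDefaultSIMTLayoutInfo(ctx, rank, subgroupSize) above.
static ToyLayout defaultSIMTLayout(unsigned rank, int subgroupSize) {
  assert(rank == 1 || rank == 2);
  if (rank == 1)
    return {{subgroupSize}, {1}};
  return {{1, subgroupSize}, {1, 1}};
}

int main() {
  ToyLayout source = defaultSIMTLayout(2, 16);
  ToyLayout result = defaultSIMTLayout(1, 16);
  assert(source.laneLayout == std::vector<int>({1, 16}));
  assert(result.laneLayout == std::vector<int>({16}));
  return 0;
}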
propagateIfChanged(operands[1], operands[1]->meet(resultLayout)); @@ -557,15 +624,53 @@ void LayoutInfoPropagation::visitDpasOp( ArrayRef<const LayoutInfoLattice *> results) { VectorType aTy = dpas.getLhsType(); VectorType bTy = dpas.getRhsType(); - propagateIfChanged( - operands[0], operands[0]->meet(getSIMTLayoutInfoForDPASOperand(aTy, 0))); - propagateIfChanged( - operands[1], operands[1]->meet(getSIMTLayoutInfoForDPASOperand(bTy, 1))); + + auto uArch = getUArch(getChipStr(dpas).value_or("")); + const int subgroupSize = uArch->getSubgroupSize(); + const auto *uArchInstruction = + dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction( + xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)); + + const unsigned dataALen = aTy.getShape().front(); + auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType()); + const int maxALen = + getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen)); + if (maxALen == -1) + dpas.emitWarning( + "No suitable instruction multiple found for the given shape."); + + const unsigned dataBLen = bTy.getShape().back(); + auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType()); + const int maxBLen = + getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen)); + if (maxBLen == -1) + dpas.emitWarning( + "No suitable instruction multiple found for the given shape."); + SmallVector<int> instDataA = {maxALen, subgroupSize}; + SmallVector<int> instDataB = {subgroupSize, maxBLen}; + + propagateIfChanged(operands[0], + operands[0]->meet(getSIMTLayoutInfoForDPASOperand( + aTy, 0, uArch, instDataA, + uArchInstruction->getPackedFormatBitSizeA()))); + propagateIfChanged(operands[1], + operands[1]->meet(getSIMTLayoutInfoForDPASOperand( + bTy, 1, uArch, instDataB, + uArchInstruction->getPackedFormatBitSizeB()))); if (operands.size() > 2) { VectorType cTy = dpas.getAccType(); - propagateIfChanged( - operands[2], - operands[2]->meet(getSIMTLayoutInfoForDPASOperand(cTy, 2))); + const unsigned dataCLen = bTy.getShape().back(); + auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType()); + const int maxCLen = + getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen)); + if (maxCLen == -1) + dpas.emitWarning( + "No suitable instruction multiple found for the given shape."); + SmallVector<int> instDataC = {maxALen, maxCLen}; + propagateIfChanged(operands[2], + operands[2]->meet(getSIMTLayoutInfoForDPASOperand( + cTy, 2, uArch, instDataC, + uArchInstruction->getPackedFormatBitSizeB()))); } } @@ -573,7 +678,38 @@ void LayoutInfoPropagation::visitDpasOp( void LayoutInfoPropagation::visitStoreNdOp( xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands, ArrayRef<const LayoutInfoLattice *> results) { - LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(store.getValueType()); + + auto uArch = getUArch(getChipStr(store).value_or("")); + const auto *uArchInstruction = + dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>( + uArch->getInstruction( + xegpu::uArch::InstructionKind::Subgroup2DBlockStore)); + VectorType dataTy = store.getValueType(); + auto blockWHC = uArchInstruction->getBlockWidthHeightCount( + store.getValueType().getElementType()); + if (!blockWHC) { + store.emitWarning("No known block params found for the element type."); + return; + } + auto [bWidth, bHeight, bCount] = blockWHC.value(); + SmallVector<int> instData; + int instWidth = getLargestDivisor( + static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth, + bCount); + if (instWidth == -1) + store.emitWarning( + "No suitable instruction multiple found for the given shape."); + if (dataTy.getRank() == 1) + instData = {instWidth}; + else { + int instHeight = getLargestDivisor( + static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight); + if (instHeight == -1) + store.emitWarning( + "No suitable instruction multiple found for the given shape."); + instData = {instHeight, instWidth}; + } + LayoutInfo storeLayout = + getDefaultSIMTLayoutInfo(store.getValueType(), uArch, instData, + uArchInstruction->getPackedFormatBitSize()); // Both operands should have the same layout for (LayoutInfoLattice *operand : operands) propagateIfChanged(operand, operand->meet(storeLayout)); @@ -694,10 +830,23 @@ void LayoutInfoPropagation::visitLoadGatherOp( load.emitWarning("Not propagating, non-vector payload supplied."); return; } - LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*scattered*/ true); + auto uArch = getUArch(getChipStr(load).value_or("")); + const int subgroupSize = uArch->getSubgroupSize(); + SmallVector<int> instData{subgroupSize}; + if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1) + instData.push_back(chunkSize); + else if (auto srcTdescTy = + dyn_cast<xegpu::TensorDescType>(load.getSourceType())) { + if (srcTdescTy.getChunkSizeAsInt() > 1) + instData.push_back(srcTdescTy.getChunkSizeAsInt()); + } + LayoutInfo layout = getDefaultSIMTLayoutInfo( + payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(), + /*scattered*/ true); // Mask operand should have 1D default layout. - LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1); + LayoutInfo maskLayout = + getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize); // Propagate the new layout to the tensor descriptor operand. if (isa<xegpu::TensorDescType>(load.getSourceType())) @@ -717,8 +866,10 @@ void LayoutInfoPropagation::visitCreateDescOp( // Need the layout of the descriptor to propagate to the operands. if (!descLayout.isAssigned()) return; + auto uArch = getUArch(getChipStr(createDesc).value_or("")); // For offset operand propagate 1D default layout.
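For illustration, assume a subgroup size of 16 and a chunk size of 4 (both hypothetical): the gather/scatter payloads in the visitors above end up with inst_data = [16, 4] and the scattered default lane_layout = [16, 1], while offset and mask operands use the 1-D default lane_layout = [16]. A compact check of that assembly logic:

#include <cassert>
#include <vector>

int main() {
  const int subgroupSize = 16; // assumed; queried from the uArch in the pass
  const int chunkSize = 4;     // assumed chunked access
  std::vector<int> instData{subgroupSize};
  if (chunkSize > 1)
    instData.push_back(chunkSize);
  assert((instData == std::vector<int>{16, 4}));
  // Offsets (and masks) get the 1-D default layout: lane_layout = [16],
  // lane_data = [1], i.e. one offset element per lane.
  std::vector<int> offsetLaneLayout{subgroupSize};
  assert(offsetLaneLayout.size() == 1);
  return 0;
}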
- LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1); + LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1, + uArch->getSubgroupSize()); propagateIfChanged(operands[1], operands[1]->meet(layout)); } @@ -735,18 +886,30 @@ void LayoutInfoPropagation::visitStoreScatterOp( storeScatter.emitWarning("Not propagating, non-vector payload supplied."); return; } + auto uArch = getUArch(getChipStr(storeScatter).value_or("")); + const int subgroupSize = uArch->getSubgroupSize(); + auto payloadShape = payloadTy.getShape(); if (payloadShape.size() > 1) assert( - payloadShape[0] == xegpu::targetinfo::subgroupSize && + payloadShape[0] == subgroupSize && "Expected the first dimension of 2D tensor descriptor to be equal to " "subgroup size."); - LayoutInfo payloadLayout = - getDefaultSIMTLayoutInfo(payloadTy, /*scattered=*/true); + SmallVector<int> instData{subgroupSize}; + if (auto chunkSize = storeScatter.getChunkSize().value_or(0); chunkSize > 1) + instData.push_back(chunkSize); + else if (auto dstTdescTy = + dyn_cast<xegpu::TensorDescType>(storeScatter.getDestType())) { + if (dstTdescTy.getChunkSizeAsInt() > 1) + instData.push_back(dstTdescTy.getChunkSizeAsInt()); + } + LayoutInfo payloadLayout = getDefaultSIMTLayoutInfo( + payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(), + /*scattered=*/true); LayoutInfo maskLayout = - getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1); + getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize); // Propagate the payload operand layout propagateIfChanged(operands[0], operands[0]->meet(payloadLayout)); // Propagate the destination (if tdesc) operand layout @@ -1023,9 +1186,13 @@ void XeGPUPropagateLayoutPass::runOnOperation() { LayoutInfo layout = analysis.getLayoutInfo(val); if (!layout.isAssigned()) return {}; + xegpu::DistributeLayoutAttr layoutAttr = + cast<xegpu::DistributeLayoutAttr>(layout.get()); + if (this->layoutKind == "lane") + layoutAttr = layoutAttr.dropInstData(); if (layout.isSliceLayout()) - return cast<xegpu::SliceAttr>(layout.get()); - return cast<xegpu::LayoutAttr>(layout.get()); + return cast<xegpu::SliceAttr>(layoutAttr); + return cast<xegpu::LayoutAttr>(layoutAttr); }; mlir::OpBuilder builder(&getContext()); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index d09dc19..5a3b27e 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -11,10 +11,10 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" -#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h" #include "mlir/Dialect/XeGPU/Transforms/Passes.h" #include "mlir/Dialect/XeGPU/Transforms/Transforms.h" #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" +#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -159,17 +159,18 @@ static bool requirePacked(const xegpu::LayoutAttr layout) { /// Helper function to check if the layout requires a transpose effect. static bool requireTranspose(const xegpu::LayoutAttr layout, - const std::string &chipStr) { + const xegpu::uArch::uArch *uArch) { // Return false for unsupported targets. // TODO: Add more support or move to target info.
- if (chipStr != "pvc" && chipStr != "bmg") + if (!uArch->getName().equals_insensitive("pvc") && + !uArch->getName().equals_insensitive("bmg")) return false; if (!layout) return false; auto laneLayout = layout.getEffectiveLaneLayoutAsInt(); if (laneLayout.size() != 2) return false; - return laneLayout[0] == xegpu::targetinfo::subgroupSize && laneLayout[1] == 1; + return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1; } /// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body @@ -199,6 +200,11 @@ struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> { using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern; LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, PatternRewriter &rewriter) const override { + auto uArch = getUArch(xegpu::getChipStr(gpuFuncOp).value_or("")); + if (!uArch) + return rewriter.notifyMatchFailure( + gpuFuncOp, "Subgroup distribution requires target attribute attached " + "to set the warp size"); // If the function only contains a single void return, skip. if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) { return isa<gpu::ReturnOp>(op) && !op.getNumOperands(); @@ -230,7 +236,7 @@ struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> { ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults(); auto warpOp = gpu::WarpExecuteOnLane0Op::create( rewriter, laneId.getLoc(), gpuFuncResultType, laneId, - xegpu::targetinfo::subgroupSize, newGpuFunc.getArguments(), + uArch->getSubgroupSize(), newGpuFunc.getArguments(), newGpuFunc.getArgumentTypes()); Block &warpBodyBlock = warpOp.getBodyRegion().front(); // Replace the ReturnOp of the original gpu function with a YieldOp. @@ -495,14 +501,14 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern { warpOp, "warp result is not a xegpu::LoadNd op"); auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>(); + auto uArch = getUArch(xegpu::getChipStr(loadOp).value_or("")); + if (!uArch) + return rewriter.notifyMatchFailure( + loadOp, "xegpu::LoadNdOp requires a target attribute attached to " + "determine the transpose requirement"); // Chip information is required to decide if the layout requires transpose // effect. - auto chipStr = xegpu::getChipStr(loadOp); - if (!chipStr) - return rewriter.notifyMatchFailure( - loadOp, - "xegpu::LoadNdOp require chip information to determine transpose " - "requirement"); // Expecting offsets to be present. SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets(); if (offsets.empty()) @@ -556,7 +562,7 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern { // Set the packed attribute if the layout requires it. newLoadOp.setPacked(requirePacked(layout)); // Set the transpose attribute if the layout requires it.
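The transpose requirement reduces to: on a supported target, a 2-D lane layout that maps the whole subgroup to the outer dimension forces a transposed load. A standalone sketch of that predicate, with a subgroup size of 16 assumed for illustration:

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors requireTranspose above: true only when lane_layout = [sgSize, 1].
static bool needsTranspose(const std::vector<int64_t> &laneLayout,
                           int subgroupSize) {
  if (laneLayout.size() != 2)
    return false;
  return laneLayout[0] == subgroupSize && laneLayout[1] == 1;
}

int main() {
  assert(needsTranspose({16, 1}, 16));  // column-mapped lanes: transpose
  assert(!needsTranspose({1, 16}, 16)); // default row mapping: no transpose
  return 0;
}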
- if (requireTranspose(layout, chipStr.value())) + if (requireTranspose(layout, uArch)) newLoadOp.setTranspose( DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0})); Value distributedVal = newWarpOp.getResult(operandIdx); diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp index 375e820..cf8a4d2 100644 --- a/mlir/lib/Query/Query.cpp +++ b/mlir/lib/Query/Query.cpp @@ -121,12 +121,13 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const { Operation *rootOp = qs.getRootOp(); int matchCount = 0; matcher::MatchFinder finder; + + StringRef functionName = matcher.getFunctionName(); auto matches = finder.collectMatches(rootOp, std::move(matcher)); // An extract call is recognized by considering if the matcher has a name. // TODO: Consider making the extract more explicit. - if (matcher.hasFunctionName()) { - auto functionName = matcher.getFunctionName(); + if (!functionName.empty()) { std::vector<Operation *> flattenedMatches = finder.flattenMatchedOps(matches); Operation *function = diff --git a/mlir/lib/Support/Timing.cpp b/mlir/lib/Support/Timing.cpp index fb6f82c..16306d7 100644 --- a/mlir/lib/Support/Timing.cpp +++ b/mlir/lib/Support/Timing.cpp @@ -319,7 +319,6 @@ public: void mergeChildren(AsyncChildrenMap &&other) { for (auto &thread : other) { mergeChildren(std::move(thread.second)); - assert(thread.second.empty()); } other.clear(); } diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp index b31377e..0f1bf83 100644 --- a/mlir/lib/TableGen/Type.cpp +++ b/mlir/lib/TableGen/Type.cpp @@ -56,7 +56,7 @@ std::optional<StringRef> TypeConstraint::getBuilderCall() const { StringRef value = init->getValue(); return value.empty() ? std::optional<StringRef>() : value; }) - .Default([](auto *) { return std::nullopt; }); + .Default(std::nullopt); } // Return the C++ type for this type (which may just be ::mlir::Type). diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp index eeb8725..e3bcf27 100644 --- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp @@ -390,7 +390,7 @@ llvm::DISubrange *DebugTranslation::translateImpl(DISubrangeAttr attr) { .Case<>([&](LLVM::DIGlobalVariableAttr global) { return translate(global); }) - .Default([&](Attribute attr) { return nullptr; }); + .Default(nullptr); return metadata; }; return llvm::DISubrange::get(llvmCtx, getMetadataOrNull(attr.getCount()), @@ -420,10 +420,10 @@ DebugTranslation::translateImpl(DIGenericSubrangeAttr attr) { .Case([&](LLVM::DILocalVariableAttr local) { return translate(local); }) - .Case<>([&](LLVM::DIGlobalVariableAttr global) { + .Case([&](LLVM::DIGlobalVariableAttr global) { return translate(global); }) - .Default([&](Attribute attr) { return nullptr; }); + .Default(nullptr); return metadata; }; return llvm::DIGenericSubrange::get(llvmCtx, diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index f284540..8edec99 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -4084,12 +4084,13 @@ static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, /// /// Fortran /// map(tofrom: array(2:5, 3:2)) -/// or -/// C++ -/// map(tofrom: array[1:4][2:3]) +/// /// We must calculate the initial pointer offset to pass across, this function /// performs this using bounds. 
/// +/// TODO/WARNING: This currently supports only Fortran's column-major indexing, +/// as noted below and in the comments in the function; we must extend +/// this function when we add a C++ frontend. /// NOTE: while the bounds are specified in row-major order, they currently need to be /// flipped for Fortran's column-order array allocation and access (as /// opposed to C++'s row-major), hence the backwards processing where order is @@ -4125,46 +4126,28 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, // with a pointer that's being treated like an array and we have the // underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base // address (pointer pointing to the actual data) so we must calculate the - // offset using a single index which the following two loops attempts to - // compute. - - // Calculates the size offset we need to make per row e.g. first row or - // column only needs to be offset by one, but the next would have to be - // the previous row/column offset multiplied by the extent of current row. + // offset using a single index which the following loop attempts to + // compute using the standard column-major algorithm e.g. for a 3D array: // - // For example ([1][10][100]): + // ((((c_idx * b_len) + b_idx) * a_len) + a_idx) // - // - First row/column we move by 1 for each index increment - // - Second row/column we move by 1 (first row/column) * 10 (extent/size of - // current) for 10 for each index increment - // - Third row/column we would move by 10 (second row/column) * - // (extent/size of current) 100 for 1000 for each index increment - std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)}; - for (size_t i = 1; i < bounds.size(); ++i) { - if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>( - bounds[i].getDefiningOp())) { - dimensionIndexSizeOffset.push_back(builder.CreateMul( - moduleTranslation.lookupValue(boundOp.getExtent()), - dimensionIndexSizeOffset[i - 1])); - } - } - - // Now that we have calculated how much we move by per index, we must - // multiply each lower bound offset in indexes by the size offset we - // have calculated in the previous and accumulate the results to get - // our final resulting offset. + // Note that this computes a column-major rather than a row-major offset at + // the moment; giving the frontend a way to indicate the major format, or + // standardizing/canonicalizing the order of the bounds used to compute the + // offset, may be useful in the future when there are other frontends with + // different formats.
+ std::vector<llvm::Value *> dimensionIndexSizeOffset; for (int i = bounds.size() - 1; i >= 0; --i) { if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>( bounds[i].getDefiningOp())) { - if (idx.empty()) - idx.emplace_back(builder.CreateMul( - moduleTranslation.lookupValue(boundOp.getLowerBound()), - dimensionIndexSizeOffset[i])); + if (i == ((int)bounds.size() - 1)) + idx.emplace_back( + moduleTranslation.lookupValue(boundOp.getLowerBound())); else idx.back() = builder.CreateAdd( - idx.back(), builder.CreateMul(moduleTranslation.lookupValue( - boundOp.getLowerBound()), - dimensionIndexSizeOffset[i])); + builder.CreateMul(idx.back(), moduleTranslation.lookupValue( + boundOp.getExtent())), + moduleTranslation.lookupValue(boundOp.getLowerBound())); } } } diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp index 08cac1f..5790a77 100644 --- a/mlir/lib/Transforms/ViewOpGraph.cpp +++ b/mlir/lib/Transforms/ViewOpGraph.cpp @@ -158,7 +158,8 @@ private: /// Emit a cluster (subgraph). The specified builder generates the body of the /// cluster. Return the anchor node of the cluster. - Node emitClusterStmt(function_ref<void()> builder, std::string label = "") { + Node emitClusterStmt(function_ref<void()> builder, + const std::string &label = "") { int clusterId = ++counter; os << "subgraph cluster_" << clusterId << " {\n"; os.indent(); @@ -269,7 +270,7 @@ private: } /// Emit a node statement. - Node emitNodeStmt(std::string label, StringRef shape = kShapeNode, + Node emitNodeStmt(const std::string &label, StringRef shape = kShapeNode, StringRef background = "") { int nodeId = ++counter; AttributeMap attrs;
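To make the rewritten accumulation auditable without the IRBuilder plumbing, here is a plain-C++ mirror of the new bounds loop in calculateBoundsOffset; the Bound struct and the numeric values are hypothetical, chosen only to exercise the formula from the comment above:

#include <cassert>
#include <vector>

struct Bound {
  long lowerBound;
  long extent;
};

// bounds[0] is the fastest-varying dimension, matching the backwards
// (column-major) processing in calculateBoundsOffset.
static long flattenColumnMajor(const std::vector<Bound> &bounds) {
  long idx = 0;
  for (int i = static_cast<int>(bounds.size()) - 1; i >= 0; --i) {
    if (i == static_cast<int>(bounds.size()) - 1)
      idx = bounds[i].lowerBound; // seed with the slowest dimension's index
    else
      idx = idx * bounds[i].extent + bounds[i].lowerBound; // Horner step
  }
  return idx;
}

int main() {
  // 3-D example: a_len = 4, a_idx = 1; b_len = 5, b_idx = 2; c_idx = 0.
  // ((((c_idx * b_len) + b_idx) * a_len) + a_idx) = ((0*5 + 2)*4 + 1) = 9.
  assert(flattenColumnMajor({{1, 4}, {2, 5}, {0, 3}}) == 9);
  return 0;
}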
