Diffstat (limited to 'mlir/lib')
40 files changed, 495 insertions, 201 deletions
| diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 41e333c..3a307a0 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -935,7 +935,7 @@ static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {        .Case([](Float6E2M3FNType) { return 2u; })        .Case([](Float6E3M2FNType) { return 3u; })        .Case([](Float4E2M1FNType) { return 4u; }) -      .Default([](Type) { return std::nullopt; }); +      .Default(std::nullopt);  }  /// If there is a scaled MFMA instruction for the input element types `aType` diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index 247dba1..cfdcd9c 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -432,7 +432,7 @@ static Value getOriginalVectorValue(Value value) {                          current = op.getSource();                          return false;                        }) -                      .Default([](Operation *) { return false; }); +                      .Default(false);      if (!skipOp) {        break; diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp index 25f1e1b..425594b 100644 --- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp +++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp @@ -259,7 +259,7 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {            }            return std::nullopt;          }) -        .Default([](auto) { return std::nullopt; }); +        .Default(std::nullopt);    }    static std::optional<std::string> getFuncName(gpu::ShuffleMode mode, diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index a9efada..ec182f1 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -846,13 +846,8 @@ struct NVGPUMBarrierInitLowering      Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),                                     adaptor.getMbarId(), rewriter);      Value count = truncToI32(b, adaptor.getCount()); -    if (isMbarrierShared(mbarrierType)) { -      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>( -          op, barrier, count, adaptor.getPredicate()); -    } else { -      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count, -                                                        adaptor.getPredicate()); -    } +    rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(op, barrier, count, +                                                      adaptor.getPredicate());      return success();    }  }; diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp index 7d0a236..76a822b 100644 --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -14,6 +14,7 @@  #include "mlir/Conversion/SCFToGPU/SCFToGPU.h" +#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"  #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"  #include "mlir/Dialect/Affine/IR/AffineOps.h"  #include "mlir/Dialect/Arith/IR/Arith.h" @@ -27,6 +28,7 @@  #include "mlir/Interfaces/SideEffectInterfaces.h"  #include "mlir/Transforms/DialectConversion.h"  #include "mlir/Transforms/RegionUtils.h" +#include 
"llvm/ADT/DenseSet.h"  #include "llvm/Support/DebugLog.h"  #include <optional> @@ -625,18 +627,49 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,    bool seenSideeffects = false;    // Whether we have left a nesting scope (and hence are no longer innermost).    bool leftNestingScope = false; +  LocalAliasAnalysis aliasAnalysis; +  llvm::DenseSet<Value> writtenBuffer;    while (!worklist.empty()) {      Operation *op = worklist.pop_back_val();      // Now walk over the body and clone it.      // TODO: This is only correct if there either is no further scf.parallel -    //       nested or this code is side-effect free. Otherwise we might need -    //       predication. We are overly conservative for now and only allow -    //       side-effects in the innermost scope. +    //       nested or this code has side-effect but the memory buffer is not +    //       alias to inner loop access buffer. Otherwise we might need +    //       predication.      if (auto nestedParallel = dyn_cast<ParallelOp>(op)) {        // Before entering a nested scope, make sure there have been no -      // sideeffects until now. -      if (seenSideeffects) -        return failure(); +      // sideeffects until now or the nested operations do not access the +      // buffer written by outer scope. +      if (seenSideeffects) { +        WalkResult walkRes = nestedParallel.walk([&](Operation *nestedOp) { +          if (isMemoryEffectFree(nestedOp)) +            return WalkResult::advance(); + +          auto memEffectInterface = dyn_cast<MemoryEffectOpInterface>(nestedOp); +          if (!memEffectInterface) +            return WalkResult::advance(); + +          SmallVector<MemoryEffects::EffectInstance> effects; +          memEffectInterface.getEffects(effects); +          for (const MemoryEffects::EffectInstance &effect : effects) { +            if (isa<MemoryEffects::Read>(effect.getEffect()) || +                isa<MemoryEffects::Write>(effect.getEffect())) { +              Value baseBuffer = effect.getValue(); +              if (!baseBuffer) +                return WalkResult::interrupt(); +              for (Value val : writtenBuffer) { +                if (aliasAnalysis.alias(baseBuffer, val) != +                    AliasResult::NoAlias) { +                  return WalkResult::interrupt(); +                } +              } +            } +          } +          return WalkResult::advance(); +        }); +        if (walkRes.wasInterrupted()) +          return failure(); +      }        // A nested scf.parallel needs insertion of code to compute indices.        // Insert that now. This will also update the worklist with the loops        // body. @@ -650,6 +683,7 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,        rewriter.setInsertionPointAfter(parent);        leftNestingScope = true;        seenSideeffects = false; +      writtenBuffer.clear();      } else if (auto reduceOp = dyn_cast<scf::ReduceOp>(op)) {        // Convert scf.reduction op        auto parentLoop = op->getParentOfType<ParallelOp>(); @@ -682,6 +716,24 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,        Operation *clone = rewriter.clone(*op, cloningMap);        cloningMap.map(op->getResults(), clone->getResults());        // Check for side effects. +      if (!isMemoryEffectFree(clone)) { +        // Record the buffer accessed by the operations with write effects. 
+        if (auto memEffectInterface = +                dyn_cast<MemoryEffectOpInterface>(clone)) { +          SmallVector<MemoryEffects::EffectInstance> effects; +          memEffectInterface.getEffects(effects); +          for (const MemoryEffects::EffectInstance &effect : effects) { +            if (isa<MemoryEffects::Write>(effect.getEffect())) { +              Value writtenBase = effect.getValue(); +              // Conservatively return failure if we cannot find the written +              // address. +              if (!writtenBase) +                return failure(); +              writtenBuffer.insert(writtenBase); +            } +          } +        } +      }        // TODO: Handle region side effects properly.        seenSideeffects |=            !isMemoryEffectFree(clone) || clone->getNumRegions() != 0; diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 41d8d53..69a317ec 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -716,7 +716,7 @@ lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,    accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,                                                           llvmType, accumulator);    return LLVMRedIntrinOp::create(rewriter, loc, llvmType, -                                 /*startValue=*/accumulator, vectorOperand, +                                 /*start_value=*/accumulator, vectorOperand,                                   fmf);  } @@ -743,7 +743,7 @@ static Value lowerPredicatedReductionWithStartValue(    Value vectorLength =        createVectorLengthValue(rewriter, loc, vectorOperand.getType());    return LLVMVPRedIntrinOp::create(rewriter, loc, llvmType, -                                   /*startValue=*/accumulator, vectorOperand, +                                   /*start_value=*/accumulator, vectorOperand,                                     mask, vectorLength);  } diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp index e2c7d80..91c1aa5 100644 --- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp +++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp @@ -46,7 +46,7 @@ static bool isZeroConstant(Value val) {            [](auto floatAttr) { return floatAttr.getValue().isZero(); })        .Case<IntegerAttr>(            [](auto intAttr) { return intAttr.getValue().isZero(); }) -      .Default([](auto) { return false; }); +      .Default(false);  }  static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter, diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp index e08cc6f..d428fbf 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -1106,10 +1106,7 @@ static bool isUniformDefinition(Value value,        return false;    } -  if (!value.getType().isIntOrIndexOrFloat()) -    return false; - -  return true; +  return value.getType().isIntOrIndexOrFloat();  }  /// Generates a broadcast op for the provided uniform value using the diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 898d76c..980442e 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2751,7 +2751,7 @@ std::optional<TypedAttr> 
mlir::arith::getNeutralElement(Operation *op) {            .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })            .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })            .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; }) -          .Default([](Operation *op) { return std::nullopt; }); +          .Default(std::nullopt);    if (!maybeKind) {      return std::nullopt;    } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index d9d6934..8655ed3 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -95,12 +95,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,  /// Return the FuncOp called by `callOp`.  static FuncOp getCalledFunction(CallOpInterface callOp,                                  SymbolTableCollection &symbolTables) { -  SymbolRefAttr sym = -      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee()); -  if (!sym) -    return nullptr; -  return dyn_cast_or_null<FuncOp>( -      symbolTables.lookupNearestSymbolFrom(callOp, sym)); +  return dyn_cast_or_null<FuncOp>(callOp.resolveCallableInTable(&symbolTables));  }  /// Return the FuncOp called by `callOp`. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index aa53f94..c233e24 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -285,12 +285,8 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {  static func::FuncOp  getCalledFunction(func::CallOp callOp,                    mlir::SymbolTableCollection &symbolTable) { -  SymbolRefAttr sym = -      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee()); -  if (!sym) -    return nullptr;    return dyn_cast_or_null<func::FuncOp>( -      symbolTable.lookupNearestSymbolFrom(callOp, sym)); +      callOp.resolveCallableInTable(&symbolTable));  }  /// Return "true" if the given function signature has tensor semantics. diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp index d2c2138..025d1ac 100644 --- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp @@ -330,7 +330,7 @@ static Value getBase(Value v) {                v = op.getSrc();                return true;              }) -            .Default([](Operation *) { return false; }); +            .Default(false);      if (!shouldContinue)        break;    } @@ -354,7 +354,7 @@ static Value propagatesCapture(Operation *op) {        .Case([](memref::TransposeOp transpose) { return transpose.getIn(); })        .Case<memref::ExpandShapeOp, memref::CollapseShapeOp>(            [](auto op) { return op.getSrc(); }) -      .Default([](Operation *) { return Value(); }); +      .Default(nullptr);  }  /// Returns `true` if the given operation is known to capture the given value, @@ -371,7 +371,7 @@ static std::optional<bool> getKnownCapturingStatus(Operation *op, Value v) {        // These operations are known not to capture.        .Case([](memref::DeallocOp) { return false; })        // By default, we don't know anything. 
-      .Default([](Operation *) { return std::nullopt; }); +      .Default(std::nullopt);  }  /// Returns `true` if the value may be captured by any of its users, i.e., if diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp index 81c3069..ec1571a 100644 --- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp @@ -416,13 +416,39 @@ createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op,    if (ci.clusterSize >= 32) {      if (chipset.majorVersion <= 9) {        // Broadcast last value from each row to next row. -      // Use row mask to avoid polluting rows 1 and 3. +      // Use row mask to avoid polluting row 0 (and row 2 if wave-64).        dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res,                                    amdgpu::DPPPerm::row_bcast_15,                                    rewriter.getUnitAttr(), 0xa, allBanks,                                    /*bound_ctrl*/ false);        res = vector::makeArithReduction(            rewriter, loc, gpu::convertReductionKind(mode), res, dpp); + +      // For subgroupSize = 64, at this point lanes [16, 32) contain the full +      // reduction over lanes [0, 32), but lanes [0, 16) do not. Similarly, +      // lanes [48, 64) contain the full reduction over lanes [32, 64), but +      // lanes [32, 48) do not. +      // +      // If subgroup size is 64 and cluster size is 64, we don't need lanes [0, +      // 16) and [32, 48) to have the correct cluster-32 reduction values at +      // this point, because only lane 63's value will ultimately be read in +      // this full-cluster case. +      // +      // If subgroup size is 64 and cluster size is 32, we need to ensure that +      // lanes [0, 16) and [32, 48) have the correct final cluster-32 reduction +      // values (subgroup_reduce guarantees that all lanes within each cluster +      // contain the final reduction value). We do this by broadcasting lane +      // 31's value to lanes [0, 16) and lane 63's value to lanes [32, 48). +      // +      // See https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations +      // for an illustration of how this within-cluster broadcast works with a +      // swizzle. +      if (ci.subgroupSize == 64 && ci.clusterSize == 32) { +        res = +            amdgpu::SwizzleBitModeOp::create(rewriter, loc, res, /*and_mask=*/0, +                                             /*or_mask=*/31, +                                             /*xor_mask=*/0); +      }      } else if (chipset.majorVersion <= 12) {        // Use a permute lane to cross rows (row 1 <-> row 0, row 3 <-> row 2).        
Value uint32Max = arith::ConstantOp::create( diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 3eae67f..2731069 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -698,7 +698,7 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,                         return structType.getBody()[memberIndex];                       return nullptr;                     }) -                   .Default(Type(nullptr)); +                   .Default(nullptr);    }  } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index cee943d..7d9058c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -1111,7 +1111,7 @@ memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,            .Case<IntegerType, FloatType>([](auto type) {              return type.getWidth() % 8 == 0 && type.getWidth() > 0;            }) -          .Default([](Type) { return false; }); +          .Default(false);    if (!canConvertType)      return false; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index ac35eea..ce93d18 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -798,7 +798,7 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {            // clang-format on            .Case<PtrLikeTypeInterface>(                [](Type type) { return isCompatiblePtrType(type); }) -          .Default([](Type) { return false; }); +          .Default(false);    if (!result)      compatibleTypes.erase(type); diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index f0de4db..a5ffb9e 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -896,6 +896,12 @@ std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,    } else if (type == NVVM::MMATypes::f32) {      elementType = builder.getF32Type();      numberElements = 8; +  } else if (type == NVVM::MMATypes::f64) { +    elementType = builder.getF64Type(); +    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b) +      numberElements = 1; +    else +      numberElements = 2;    } else if (type == NVVM::MMATypes::tf32) {      elementType = builder.getI32Type();      numberElements = 4; @@ -954,6 +960,14 @@ LogicalResult NVVM::WMMALoadOp::verify() {      return emitOpError() << "invalid attribute combination";    std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(        getEltype(), getFrag(), getM(), getN(), getK(), getContext()); +  // Special case for f64 fragments +  Type f64Ty = Float64Type::get(getContext()); +  if (typeInfo.first == f64Ty && typeInfo.second == 1) { +    if (getType() != f64Ty) +      return emitOpError("expected destination type to be f64"); +    return success(); +  } +  // Everything else is a struct    Type dstType = LLVM::LLVMStructType::getLiteral(        getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));    if (getType() != dstType) @@ -1608,9 +1622,52 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,  }  //===----------------------------------------------------------------------===// +// getPtx methods +//===----------------------------------------------------------------------===// + +std::string NVVM::MBarrierInitOp::getPtx() { +  
unsigned addressSpace = +      llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace(); +  return (addressSpace == NVVMMemorySpace::Shared) +             ? std::string("mbarrier.init.shared.b64 [%0], %1;") +             : std::string("mbarrier.init.b64 [%0], %1;"); +} + +//===----------------------------------------------------------------------===//  // getIntrinsicID/getIntrinsicIDAndArgs methods  //===----------------------------------------------------------------------===// +mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs( +    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { +  auto thisOp = cast<NVVM::MBarrierInitOp>(op); +  unsigned addressSpace = +      llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType()) +          .getAddressSpace(); +  llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared) +                               ? llvm::Intrinsic::nvvm_mbarrier_init_shared +                               : llvm::Intrinsic::nvvm_mbarrier_init; + +  // Fill the Intrinsic Args +  llvm::SmallVector<llvm::Value *> args; +  args.push_back(mt.lookupValue(thisOp.getAddr())); +  args.push_back(mt.lookupValue(thisOp.getCount())); + +  return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs( +    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { +  auto thisOp = cast<NVVM::MBarrierInvalOp>(op); +  unsigned addressSpace = +      llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType()) +          .getAddressSpace(); +  llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared) +                               ? llvm::Intrinsic::nvvm_mbarrier_inval_shared +                               : llvm::Intrinsic::nvvm_mbarrier_inval; + +  return {id, {mt.lookupValue(thisOp.getAddr())}}; +} +  #define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \    llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index cbc565b..3dc45ed 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region ®ion,                                       OpAsmSetValueNameFn setNameFn) {    for (Value v : getRegionInputArgs())      setNameFn(v, "in"); +  for (Value v : getRegionOutputArgs()) +    setNameFn(v, "init");  }  void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { @@ -1495,14 +1497,14 @@ void MapOp::build(    if (bodyBuild)      buildGenericRegion(builder, result.location, *result.regions.front(), -                       inputs, /*outputs=*/{}, bodyBuild); +                       inputs, /*outputs=*/{init}, bodyBuild);  }  static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,                                   const OperationName &payloadOpName,                                   const NamedAttrList &payloadOpAttrs,                                   ArrayRef<Value> operands, -                                 bool initFirst = false) { +                                 bool initFirst = false, bool mapInit = true) {    OpBuilder b(parser.getContext());    Region *body = result.addRegion();    Block &block = body->emplaceBlock(); @@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,    // If initFirst flag is enabled, we consider init as the 
first position of    // payload operands.    if (initFirst) { -    payloadOpOperands.push_back(block.getArguments().back()); +    if (mapInit) +      payloadOpOperands.push_back(block.getArguments().back());      for (const auto &arg : block.getArguments().drop_back())        payloadOpOperands.push_back(arg);    } else {      payloadOpOperands = {block.getArguments().begin(), -                         block.getArguments().end()}; +                         block.getArguments().end() - int(!mapInit)};    }    Operation *payloadOp = b.create( @@ -1553,8 +1556,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {    if (payloadOpName.has_value()) {      if (!result.operands.empty())        addBodyWithPayloadOp(parser, result, payloadOpName.value(), -                           payloadOpAttrs, -                           ArrayRef(result.operands).drop_back()); +                           payloadOpAttrs, ArrayRef(result.operands), +                           /*initFirst=*/false, /*mapInit=*/false);      else        result.addRegion();    } else { @@ -1570,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {    return success();  } -static bool canUseShortForm(Block *body, bool initFirst = false) { +static bool canUseShortForm(Block *body, bool initFirst = false, +                            bool mapInit = true) { +  // `initFirst == true` implies that we want to map the init arg. +  if (initFirst && !mapInit) +    return false;    // Check if the body can be printed in short form. The following 4 conditions    // must be satisfied: @@ -1582,7 +1589,7 @@    // 2) The payload op must have the same number of operands as the number of    //    block arguments.    if (payload.getNumOperands() == 0 || -      payload.getNumOperands() != body->getNumArguments()) +      payload.getNumOperands() != body->getNumArguments() - int(!mapInit))      return false;    // 3) If `initFirst` is true (e.g., for reduction ops), the init block @@ -1600,7 +1607,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false,      }    } else {      for (const auto &[operand, bbArg] : -         llvm::zip(payload.getOperands(), body->getArguments())) { +         llvm::zip(payload.getOperands(), +                   body->getArguments().drop_back(int(!mapInit)))) {        if (bbArg != operand)          return false;      }    } @@ -1632,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {  void MapOp::print(OpAsmPrinter &p) {    Block *mapper = getBody(); -  bool useShortForm = canUseShortForm(mapper); +  bool useShortForm = +      canUseShortForm(mapper, /*initFirst=*/false, /*mapInit=*/false);    if (useShortForm) {      printShortForm(p, &mapper->getOperations().front());    } @@ -1658,11 +1667,13 @@ LogicalResult MapOp::verify() {    auto *bodyBlock = getBody();    auto blockArgs = bodyBlock->getArguments(); -  // Checks if the number of `inputs` match the arity of the `mapper` region. -  if (getInputs().size() != blockArgs.size()) +  // Checks if the number of `inputs` + `init` matches the arity of the +  // `mapper` region. 
+  if (getInputs().size() + 1 != blockArgs.size())      return emitOpError() << "expects number of operands to match the arity of "                              "mapper, but got: " -                         << getInputs().size() << " and " << blockArgs.size(); +                         << getInputs().size() + 1 << " and " +                         << blockArgs.size();    // The parameters of mapper should all match the element type of inputs.    for (const auto &[bbArgType, inputArg] : diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 8b89244..3a43382 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1958,7 +1958,7 @@ enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };  /// Return true if either `op` or `permutation` are empty to allow a simpler  /// polymorphic implementation.  template <typename RelayoutOpTy> -bool isValidPackingPermutation( +static bool isValidPackingPermutation(      RelayoutOpTy op, ArrayRef<int64_t> permutation,      OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {    static_assert( @@ -4322,9 +4322,10 @@ DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(  // InsertSliceToCopyOp  //===----------------------------------------------------------------------===//  template <typename OpTy> -DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, -                                 transform::ApplyToEachResultList &results, -                                 transform::TransformState &state) { +static DiagnosedSilenceableFailure +doit(RewriterBase &rewriter, OpTy target, +     transform::ApplyToEachResultList &results, +     transform::TransformState &state) {    static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,                                  tensor::ParallelInsertSliceOp>() &&                  "wrong op type"); @@ -4499,7 +4500,7 @@ DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(              maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);              return true;            }) -          .Default([&](Operation *op) { return false; }); +          .Default(false);    if (!supported) {      DiagnosedSilenceableFailure diag = diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp index 3e31393..75bb175 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -31,10 +31,8 @@ using namespace mlir;  using namespace mlir::linalg;  static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) { -  // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot -  // trivially generalize a `linalg.map`, as it does not use the output as -  // region arguments in the block. -  if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp)) +  // Bailout if `linalgOp` is already a generic. +  if (isa<GenericOp>(linalgOp))      return failure();    // Check if the operation has exactly one region.    
if (linalgOp->getNumRegions() != 1) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index f05ffa8..6519c4f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -322,7 +322,7 @@ promoteSubViews(ImplicitLocOpBuilder &b,                  tmp = arith::ConstantOp::create(b, IntegerAttr::get(et, 0));                return complex::CreateOp::create(b, t, tmp, tmp);              }) -            .Default([](auto) { return Value(); }); +            .Default(nullptr);      if (!fillVal)        return failure();      linalg::FillOp::create(b, fillVal, promotionInfo->fullLocalView); diff --git a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp index 27ccf3c..6becc1f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp @@ -89,7 +89,7 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,                  ValueRange{input, collapsedKernel, iZp, kZp},                  ValueRange{collapsedInit}, stride, dilation);            }) -          .Default([](Operation *op) { return nullptr; }); +          .Default(nullptr);    if (!newConv)      return failure();    for (auto attr : preservedAttrs) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 0f317ea..cb6199f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -656,7 +656,7 @@ mlir::linalg::getCombinerOpKind(Operation *combinerOp) {            [&](auto op) { return CombiningKind::MUL; })        .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })        .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; }) -      .Default([&](auto op) { return std::nullopt; }); +      .Default(std::nullopt);  }  /// Check whether `outputOperand` is a reduction with a single combiner diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index 1208fdd..e685089 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -104,7 +104,7 @@ static Value getTargetMemref(Operation *op) {                       vector::MaskedStoreOp, vector::TransferReadOp,                       vector::TransferWriteOp>(            [](auto op) { return op.getBase(); }) -      .Default([](auto) { return Value{}; }); +      .Default(nullptr);  }  template <typename T> diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp index 660c313..fbac28e 100644 --- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp +++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp @@ -145,3 +145,13 @@ std::string mlir::acc::getRecipeName(mlir::acc::RecipeKind kind,    return recipeName;  } + +mlir::Value mlir::acc::getBaseEntity(mlir::Value val) { +  if (auto partialEntityAccessOp = +          dyn_cast<PartialEntityAccessOpInterface>(val.getDefiningOp())) { +    if (!partialEntityAccessOp.isCompleteView()) +      return partialEntityAccessOp.getBaseEntity(); +  } + +  return val; +} diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp index 4ebd90d..d380c46 100644 --- 
a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp @@ -55,7 +55,7 @@ static bool isShapePreserving(ForOp forOp, int64_t arg) {                               ? forOp.getInitArgs()[opResult.getResultNumber()]                               : Value();                  }) -                .Default([&](auto op) { return Value(); }); +                .Default(nullptr);    }    return false;  } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 0c8114d..938952e 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -346,7 +346,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {        llvm::TypeSwitch<Type, Type>(getType())            .Case<spirv::CooperativeMatrixType>(                [](auto coopType) { return coopType.getElementType(); }) -          .Default([](Type) { return nullptr; }); +          .Default(nullptr);    // Case 1. -- matrices.    if (coopElementType) { @@ -1708,7 +1708,7 @@ LogicalResult spirv::MatrixTimesScalarOp::verify() {        llvm::TypeSwitch<Type, Type>(getMatrix().getType())            .Case<spirv::CooperativeMatrixType, spirv::MatrixType>(                [](auto matrixType) { return matrixType.getElementType(); }) -          .Default([](Type) { return nullptr; }); +          .Default(nullptr);    assert(elementType && "Unhandled type"); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index f895807..d1e275d 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -731,7 +731,7 @@ std::optional<int64_t> SPIRVType::getSizeInBytes() {            return *elementSize * type.getNumElements();          return std::nullopt;        }) -      .Default(std::optional<int64_t>()); +      .Default(std::nullopt);  }  //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 88e1ab6..cb9b7f6 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -1467,7 +1467,7 @@ mlir::spirv::getNativeVectorShape(Operation *op) {    return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)        .Case<vector::ReductionOp, vector::TransposeOp>(            [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); }) -      .Default([](Operation *) { return std::nullopt; }); +      .Default(std::nullopt);  }  LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) { diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index ac72002..110bfdc 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -41,10 +41,6 @@  using namespace mlir;  using namespace mlir::tensor; -using llvm::divideCeilSigned; -using llvm::divideFloorSigned; -using llvm::mod; -  /// Materialize a single constant operation from a given attribute value with  /// the desired resultant type.  
Operation *TensorDialect::materializeConstant(OpBuilder &builder, diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index bce964e..c607ece 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -579,6 +579,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,        linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(),                              /*init=*/tensorDestination);    Block &linalgBody = linalgOp.getMapper().emplaceBlock(); +  linalgBody.addArgument(tensorType.getElementType(), loc);    // Create linalg::IndexOps.    rewriter.setInsertionPointToStart(&linalgBody); @@ -1068,6 +1069,7 @@ struct SplatOpInterface                                            /*inputs=*/ValueRange(),                                            /*init=*/*tensorAlloc);      Block &linalgBody = linalgOp.getMapper().emplaceBlock(); +    linalgBody.addArgument(tensorType.getElementType(), loc);      // Create linalg::IndexOps.      rewriter.setInsertionPointToStart(&linalgBody); diff --git a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp index 69e649d..bc4f5a5 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp @@ -189,7 +189,7 @@ struct PadOpToConstant final : public OpRewritePattern<PadOp> {                return constantFoldPadOp<llvm::APInt>(                    rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad);              }) -            .Default(Value()); +            .Default(nullptr);      if (!newOp)        return rewriter.notifyMatchFailure(padTensorOp, diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index ad8255a..ae3423c 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4336,7 +4336,7 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {    // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.    if (auto splat =            llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) -    DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>()); +    return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());    // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.    
return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource()); diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index f9aa28d5..83406c8 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -11,7 +11,6 @@  #include "mlir/Dialect/Index/IR/IndexOps.h"  #include "mlir/Dialect/Utils/IndexingUtils.h"  #include "mlir/Dialect/XeGPU/IR/XeGPU.h" -#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"  #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"  #include "mlir/IR/Builders.h"  #include "mlir/IR/DialectImplementation.h" @@ -229,8 +228,10 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,    }    if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) { -    return emitError() -           << "expected inst_data and lane_layout to have the same rank"; +    return emitError() << "expected inst_data and lane_layout to have the same " +                          "rank, got inst_data " +                       << inst_data.size() << ", lane_layout " +                       << lane_layout.size();    }    // sg_data is optional for Workgroup layout, but its presence requires @@ -569,8 +570,8 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,    // for gather and scatter ops, Low-precision types are packed in 32-bit units.    unsigned bitWidth = elementType.getIntOrFloatBitWidth();    int chunkAlignmentFactor = -      bitWidth < targetinfo::packedSizeInBitsForGatherScatter -          ? targetinfo::packedSizeInBitsForGatherScatter / bitWidth +      bitWidth < xegpu::uArch::generalPackedFormatBitSize +          ? xegpu::uArch::generalPackedFormatBitSize / bitWidth            : 1;    auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);    if (scatterAttr) { diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp index 8fab255..90eae87 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp @@ -14,7 +14,6 @@  #include "mlir/Dialect/MemRef/IR/MemRef.h"  #include "mlir/Dialect/Vector/IR/VectorOps.h"  #include "mlir/Dialect/XeGPU/IR/XeGPU.h" -#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"  #include "mlir/Dialect/XeGPU/Transforms/Passes.h"  #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"  #include "mlir/IR/Attributes.h" @@ -37,6 +36,8 @@  #include "llvm/Support/LogicalResult.h"  #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h" +  namespace mlir {  namespace xegpu {  #define GEN_PASS_DEF_XEGPUPROPAGATELAYOUT @@ -104,6 +105,8 @@ public:    SmallVector<int> getLaneData() const; +  SmallVector<int> getInstData() const; +    bool isSliceLayout() const {      if (!isAssigned())        return false; @@ -137,6 +140,13 @@ SmallVector<int> LayoutInfo::getLaneData() const {                               [](int64_t val) { return static_cast<int>(val); });  } +SmallVector<int> LayoutInfo::getInstData() const { +  if (!isAssigned()) +    return {}; +  return llvm::map_to_vector(storage.getEffectiveInstDataAsInt(), +                             [](int64_t val) { return static_cast<int>(val); }); +} +  void LayoutInfo::print(raw_ostream &os) const {    if (isAssigned()) {      os << storage; @@ -174,12 +184,14 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {    SmallVector<int32_t> laneLayout;   
 SmallVector<int32_t> laneData; +  SmallVector<int32_t> instData;    for (int64_t idx : permutation) {      laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));      laneData.push_back(static_cast<int32_t>(getLaneData()[idx])); +    instData.push_back(static_cast<int32_t>(getInstData()[idx]));    } -  return LayoutInfo( -      xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData)); +  return LayoutInfo(xegpu::LayoutAttr::get(storage.getContext(), instData, +                                           laneLayout, laneData));  }  //===----------------------------------------------------------------------===// @@ -192,6 +204,28 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {    using Lattice::Lattice;  }; +/// Helper Function to find a proper instruction multiple for the user-supplied +/// sg-level data shape. `candidates` are uArch allowed shapes. +/// `candidateMultiples` are uArch multiples of such shapes (e.g., block count). +template <typename T> +int getLargestDivisor(T dim, ArrayRef<T> candidates, +                      ArrayRef<T> candidateMultiples = {}) { +  static_assert(std::is_integral<T>::value, "T must be an integer type"); +  int largest = -1; +  SmallVector<T> multiples = {1}; +  if (!candidateMultiples.empty()) +    multiples = +        SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end()); +  for (T candidate : candidates) { +    for (T multiple : multiples) { +      int value = static_cast<int>(candidate * multiple); +      if (value != 0 && dim % value == 0 && value > largest) +        largest = value; +    } +  } +  return largest; +} +  /// Helper Functions to get default layouts. A `default layout` is a layout that  /// is assigned to a value when the layout is not fixed by some anchor operation  /// (like DPAS). @@ -200,18 +234,32 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {  /// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].  /// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].  static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, -                                           unsigned rank) { +                                           unsigned rank, +                                           const xegpu::uArch::uArch *uArch, +                                           ArrayRef<int> instData) {    assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");    if (rank == 1) {      return LayoutInfo( -        xegpu::LayoutAttr::get(ctx, {xegpu::targetinfo::subgroupSize}, {1})); +        xegpu::LayoutAttr::get(ctx, instData, {uArch->getSubgroupSize()}, {1}));    }    return LayoutInfo(xegpu::LayoutAttr::get( -      ctx, {1, xegpu::targetinfo::subgroupSize}, {1, 1})); +      ctx, instData, {1, uArch->getSubgroupSize()}, {1, 1})); +} + +static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, +                                           unsigned rank, int subgroupSize) { +  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector."); +  if (rank == 1) { +    return LayoutInfo(xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1})); +  } +  return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));  }  /// Helper to get the default layout for a vector type.  
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy, +                                           const xegpu::uArch::uArch *uArch, +                                           ArrayRef<int> instData, +                                           unsigned packingSize,                                             bool isScattered = false) {    // Expecting a 1D or 2D vector.    assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) && @@ -221,28 +269,25 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,           "Expected int or float element type.");    // If the rank is 1, then return default layout for 1D vector.    if (vectorTy.getRank() == 1) -    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1); +    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch, instData);    // Packing factor is determined by the element type bitwidth. -  int packingFactor = 1;    unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth(); +  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;    if (isScattered) { -    packingFactor = -        bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter -            ? xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth -            : 1; -    return LayoutInfo(xegpu::LayoutAttr::get( -        vectorTy.getContext(), {xegpu::targetinfo::subgroupSize, 1}, -        {1, packingFactor})); +    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData, +                                             {uArch->getSubgroupSize(), 1}, +                                             {1, packingFactor}));    } -  if (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault) -    packingFactor = xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth; -  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), -                                           {1, xegpu::targetinfo::subgroupSize}, +  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData, +                                           {1, uArch->getSubgroupSize()},                                             {1, packingFactor}));  }  /// Helper to get the default layout for a vector type.  static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy, +                                           const xegpu::uArch::uArch *uArch, +                                           ArrayRef<int> instData, +                                           unsigned packingSize,                                             bool isScattered = false) {    // Expecting a 1D or 2D vector.    assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) && @@ -252,27 +297,18 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,           "Expected int or float element type.");    // If the rank is 1, then return default layout for 1D vector.    if (tdescTy.getRank() == 1) -    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1); +    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch, instData);    // Packing factor is determined by the element type bitwidth.    unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth(); - +  int subgroupSize = uArch->getSubgroupSize(); +  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;    if (isScattered) { -    int packingFactor = -        bitwidth < xegpu::targetinfo::packedSizeInBitsForGatherScatter -            ? 
xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth -            : 1;      return LayoutInfo(xegpu::LayoutAttr::get( -        tdescTy.getContext(), {xegpu::targetinfo::subgroupSize, 1}, -        {1, packingFactor})); +        tdescTy.getContext(), instData, {subgroupSize, 1}, {1, packingFactor}));    } -  int packingFactor = -      (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault) -          ? xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth -          : 1; -  return LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), -                                           {1, xegpu::targetinfo::subgroupSize}, -                                           {1, packingFactor})); +  return LayoutInfo(xegpu::LayoutAttr::get( +      tdescTy.getContext(), instData, {1, subgroupSize}, {1, packingFactor}));  }  /// Helper Function to get the expected layouts for DPAS operands. `lane_data` @@ -281,25 +317,25 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,  /// `packedSizeInBitsForDefault`  /// * For B operand, the data must be packed in minimum  /// `packedSizeInBitsForDpasB` -static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, -                                                  unsigned operandNum) { +static LayoutInfo +getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum, +                                const xegpu::uArch::uArch *uArch, +                                ArrayRef<int> instData, unsigned packingSize) {    Type elementTy = vectorTy.getElementType();    assert(elementTy.isIntOrFloat() &&           "Expected int or float type in DPAS operands"); -  SmallVector<int32_t, 2> layout({1, xegpu::targetinfo::subgroupSize}); +  SmallVector<int32_t, 2> layout({1, uArch->getSubgroupSize()});    // For B operand, data must be packed in minimum `packedDpasBSizeInBits` and    // must have the VNNI format. -  if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < -                             xegpu::targetinfo::packedSizeInBitsForDpasB) { +  if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < packingSize) {      SmallVector<int32_t, 2> data( -        {static_cast<int32_t>(xegpu::targetinfo::packedSizeInBitsForDpasB / -                              elementTy.getIntOrFloatBitWidth()), +        {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),           1});      return LayoutInfo( -        xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data)); +        xegpu::LayoutAttr::get(vectorTy.getContext(), instData, layout, data));    }    // Otherwise, return the default layout for the vector type. -  return getDefaultSIMTLayoutInfo(vectorTy); +  return getDefaultSIMTLayoutInfo(vectorTy, uArch, instData, packingSize);  }  //===----------------------------------------------------------------------===// @@ -456,7 +492,37 @@ void LayoutInfoPropagation::visitPrefetchNdOp(    // Here we assign the default layout to the tensor descriptor operand of    // prefetch.    
auto tdescTy = prefetch.getTensorDescType(); -  auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy); + +  auto uArch = getUArch(getChipStr(prefetch).value_or("")); +  const auto *uArchInstruction = +      dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>( +          uArch->getInstruction( +              xegpu::uArch::InstructionKind::Subgroup2DBlockPrefetch)); + +  auto blockWHC = +      uArchInstruction->getBlockWidthHeightCount(tdescTy.getElementType()); +  if (!blockWHC) +    prefetch.emitWarning("No known block params found for the element type."); +  auto [bWidth, bHeight, bCount] = blockWHC.value(); +  SmallVector<int> instData; +  int instWidth = getLargestDivisor( +      static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth, +      bCount); +  if (instWidth == -1) +    prefetch.emitWarning( +        "No suitable instruction multiple found for the given shape."); +  if (tdescTy.getRank() == 1) +    instData = {instWidth}; +  else { +    int instHeight = getLargestDivisor( +        static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight); +    if (instHeight == -1) +      prefetch.emitWarning( +          "No suitable instruction multiple found for the given shape."); +    instData = {instHeight, instWidth}; +  } +  auto prefetchLayout = getDefaultSIMTLayoutInfo( +      tdescTy, uArch, instData, uArchInstruction->getPackedFormatBitSize());    // Propagate the layout to the source tensor descriptor.    propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));  } @@ -475,10 +541,11 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(      reduction.emitWarning("Expecting output type to be 1D vector.");      return;    } +  auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));    // Given that the result is 1D, the layout of the operand should be 2D with    // default layout. -  LayoutInfo operandLayout = -      getDefaultSIMTLayoutInfo(reduction->getContext(), 2); +  LayoutInfo operandLayout = getDefaultSIMTLayoutInfo( +      reduction->getContext(), 2, uArch->getSubgroupSize());    propagateIfChanged(operands[0], operands[0]->meet(operandLayout));    // Accumulator should have the same layout as the result.    
propagateIfChanged(operands[1], operands[1]->meet(resultLayout)); @@ -557,15 +624,53 @@ void LayoutInfoPropagation::visitDpasOp(      ArrayRef<const LayoutInfoLattice *> results) {    VectorType aTy = dpas.getLhsType();    VectorType bTy = dpas.getRhsType(); -  propagateIfChanged( -      operands[0], operands[0]->meet(getSIMTLayoutInfoForDPASOperand(aTy, 0))); -  propagateIfChanged( -      operands[1], operands[1]->meet(getSIMTLayoutInfoForDPASOperand(bTy, 1))); + +  auto uArch = getUArch(getChipStr(dpas).value_or("")); +  const int subgroupSize = uArch->getSubgroupSize(); +  const auto *uArchInstruction = +      dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction( +          xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)); + +  const unsigned dataALen = aTy.getShape().front(); +  auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType()); +  const int maxALen = +      getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen)); +  if (maxALen == -1) +    dpas.emitWarning( +        "No suitable instruction multiple found for the given shape."); + +  const unsigned dataBLen = bTy.getShape().back(); +  auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType()); +  const int maxBLen = +      getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen)); +  if (maxBLen == -1) +    dpas.emitWarning( +        "No suitable instruction multiple found for the given shape."); +  SmallVector<int> instDataA = {maxALen, subgroupSize}; +  SmallVector<int> instDataB = {subgroupSize, maxBLen}; + +  propagateIfChanged(operands[0], +                     operands[0]->meet(getSIMTLayoutInfoForDPASOperand( +                         aTy, 0, uArch, instDataA, +                         uArchInstruction->getPackedFormatBitSizeA()))); +  propagateIfChanged(operands[1], +                     operands[1]->meet(getSIMTLayoutInfoForDPASOperand( +                         bTy, 1, uArch, instDataB, +                         uArchInstruction->getPackedFormatBitSizeB())));    if (operands.size() > 2) {      VectorType cTy = dpas.getAccType(); -    propagateIfChanged( -        operands[2], -        operands[2]->meet(getSIMTLayoutInfoForDPASOperand(cTy, 2))); +    const unsigned dataCLen = bTy.getShape().back(); +    auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType()); +    const int maxCLen = +        getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen)); +    if (maxCLen == -1) +      dpas.emitWarning( +          "No suitable instruction multiple found for the given shape."); +    SmallVector<int> instDataC = {maxALen, maxCLen}; +    propagateIfChanged(operands[2], +                       operands[2]->meet(getSIMTLayoutInfoForDPASOperand( +                           cTy, 2, uArch, instDataC, +                           uArchInstruction->getPackedFormatBitSizeB())));    }  } @@ -573,7 +678,38 @@ void LayoutInfoPropagation::visitDpasOp(  void LayoutInfoPropagation::visitStoreNdOp(      xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,      ArrayRef<const LayoutInfoLattice *> results) { -  LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(store.getValueType()); + +  auto uArch = getUArch(getChipStr(store).value_or("")); +  const auto *uArchInstruction = +      dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>( +          uArch->getInstruction( +              xegpu::uArch::InstructionKind::Subgroup2DBlockStore)); +  VectorType dataTy = store.getValueType(); +  auto blockWHC = 
@@ -573,7 +678,38 @@ void LayoutInfoPropagation::visitDpasOp(
 void LayoutInfoPropagation::visitStoreNdOp(
     xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-  LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(store.getValueType());
+
+  auto uArch = getUArch(getChipStr(store).value_or(""));
+  const auto *uArchInstruction =
+      dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
+          uArch->getInstruction(
+              xegpu::uArch::InstructionKind::Subgroup2DBlockStore));
+  VectorType dataTy = store.getValueType();
+  auto blockWHC = uArchInstruction->getBlockWidthHeightCount(
+      store.getValueType().getElementType());
+  if (!blockWHC)
+    store.emitWarning("No known block params found for the element type.");
+  auto [bWidth, bHeight, bCount] = blockWHC.value();
+  SmallVector<int> instData;
+  int instWidth = getLargestDivisor(
+      static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth,
+      bCount);
+  if (instWidth == -1)
+    store.emitWarning(
+        "No suitable instruction multiple found for the given shape.");
+  if (dataTy.getRank() == 1)
+    instData = {instWidth};
+  else {
+    int instHeight = getLargestDivisor(
+        static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
+    if (instHeight == -1)
+      store.emitWarning(
+          "No suitable instruction multiple found for the given shape.");
+    instData = {instHeight, instWidth};
+  }
+  LayoutInfo storeLayout =
+      getDefaultSIMTLayoutInfo(store.getValueType(), uArch, instData,
+                               uArchInstruction->getPackedFormatBitSize());
   // Both operands should have the same layout
   for (LayoutInfoLattice *operand : operands)
     propagateIfChanged(operand, operand->meet(storeLayout));
@@ -694,10 +830,23 @@ void LayoutInfoPropagation::visitLoadGatherOp(
     load.emitWarning("Not propagating, non-vector payload supplied.");
     return;
   }
-  LayoutInfo layout = getDefaultSIMTLayoutInfo(payloadTy, /*scattered*/ true);
+  auto uArch = getUArch(getChipStr(load).value_or(""));
+  const int subgroupSize = uArch->getSubgroupSize();
+  SmallVector<int> instData{subgroupSize};
+  if (auto chunkSize = load.getChunkSize().value_or(0); chunkSize > 1)
+    instData.push_back(chunkSize);
+  else if (auto srcTdescTy =
+               dyn_cast<xegpu::TensorDescType>(load.getSourceType())) {
+    if (srcTdescTy.getChunkSizeAsInt() > 1)
+      instData.push_back(srcTdescTy.getChunkSizeAsInt());
+  }
+  LayoutInfo layout = getDefaultSIMTLayoutInfo(
+      payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
+      /*scattered*/ true);
   // Mask operand should have 1D default layout.
-  LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1);
+  LayoutInfo maskLayout =
+      getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
   // Propagate the new layout to the tensor descriptor operand.
   if (isa<xegpu::TensorDescType>(load.getSourceType()))
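The gather and scatter visitors build instData the same way: the distributed dimension is the subgroup size, and a chunked access appends its chunk size. A standalone restatement of that rule; the helper name and signature are hypothetical, not part of the patch:

#include "llvm/ADT/SmallVector.h"
#include <optional>

// One entry per distributed dimension: lanes first, then the per-lane chunk
// if either the op itself or its tensor descriptor declares one.
static llvm::SmallVector<int>
scatterInstData(int subgroupSize, int opChunkSize,
                std::optional<int> tdescChunkSize) {
  llvm::SmallVector<int> instData{subgroupSize};
  if (opChunkSize > 1)
    instData.push_back(opChunkSize);
  else if (tdescChunkSize && *tdescChunkSize > 1)
    instData.push_back(*tdescChunkSize);
  return instData;
}
// e.g. subgroup 16 with chunk 8 yields {16, 8}; no chunk yields {16}.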
@@ -717,8 +866,10 @@ void LayoutInfoPropagation::visitCreateDescOp(
   // Need the layout of the descriptor to propagate to the operands.
   if (!descLayout.isAssigned())
     return;
+  auto uArch = getUArch(getChipStr(createDesc).value_or(""));
   // For offset operand propagate 1D default layout.
-  LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1);
+  LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
+                                               uArch->getSubgroupSize());
   propagateIfChanged(operands[1], operands[1]->meet(layout));
 }
@@ -735,18 +886,30 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
     return;
   }
+  auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
+  const int subgroupSize = uArch->getSubgroupSize();
+
   auto payloadShape = payloadTy.getShape();
   if (payloadShape.size() > 1)
     assert(
-        payloadShape[0] == xegpu::targetinfo::subgroupSize &&
+        payloadShape[0] == subgroupSize &&
         "Expected the first dimension of 2D tensor descriptor to be equal to "
         "subgroup size.");
-  LayoutInfo payloadLayout =
-      getDefaultSIMTLayoutInfo(payloadTy, /*scattered=*/true);
+  SmallVector<int> instData{subgroupSize};
+  if (auto chunkSize = storeScatter.getChunkSize().value_or(0); chunkSize > 1)
+    instData.push_back(chunkSize);
+  else if (auto dstTdescTy =
+               dyn_cast<xegpu::TensorDescType>(storeScatter.getDestType())) {
+    if (dstTdescTy.getChunkSizeAsInt() > 1)
+      instData.push_back(dstTdescTy.getChunkSizeAsInt());
+  }
+  LayoutInfo payloadLayout = getDefaultSIMTLayoutInfo(
+      payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
+      /*scattered=*/true);
   LayoutInfo maskLayout =
-      getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1);
+      getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
   // Propagate the payload operand layout
   propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
   // Propagate the destination (if tdesc) operand layout
@@ -1023,9 +1186,13 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
     LayoutInfo layout = analysis.getLayoutInfo(val);
     if (!layout.isAssigned())
       return {};
+    xegpu::DistributeLayoutAttr layoutAttr =
+        cast<xegpu::DistributeLayoutAttr>(layout.get());
+    if (this->layoutKind == "lane")
+      layoutAttr = layoutAttr.dropInstData();
    if (layout.isSliceLayout())
-      return cast<xegpu::SliceAttr>(layout.get());
-    return cast<xegpu::LayoutAttr>(layout.get());
+      return cast<xegpu::SliceAttr>(layoutAttr);
+    return cast<xegpu::LayoutAttr>(layoutAttr);
   };
 
   mlir::OpBuilder builder(&getContext());
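With the new layoutKind option set to "lane", the pass calls dropInstData() before attaching the attribute, keeping only the lane-level distribution. A toy model of that operation; the struct and field names below are illustrative stand-ins, not the real xegpu::LayoutAttr:

#include "llvm/ADT/SmallVector.h"

// Toy layout: dropInstData() clears the instruction-level tile while the
// lane-level fields survive, e.g. (in XeGPU attribute syntax, assumed here)
//   #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>
// would become
//   #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
struct ToyLayout {
  llvm::SmallVector<int, 2> instData;   // e.g. {8, 16}
  llvm::SmallVector<int, 2> laneLayout; // e.g. {1, 16}
  llvm::SmallVector<int, 2> laneData;   // e.g. {1, 1}

  ToyLayout dropInstData() const { return {{}, laneLayout, laneData}; }
};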
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d09dc19..5a3b27e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -11,10 +11,10 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -159,17 +159,18 @@ static bool requirePacked(const xegpu::LayoutAttr layout) {
 /// Helper function to check if the layout requires a transpose effect.
 static bool requireTranspose(const xegpu::LayoutAttr layout,
-                             const std::string &chipStr) {
+                             const xegpu::uArch::uArch *uArch) {
   // Return false for unsupported targets.
   // TODO: Add more support or move to target info.
-  if (chipStr != "pvc" && chipStr != "bmg")
+  if (!uArch->getName().equals_insensitive("pvc") &&
+      !uArch->getName().equals_insensitive("bmg"))
     return false;
   if (!layout)
     return false;
   auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
   if (laneLayout.size() != 2)
     return false;
-  return laneLayout[0] == xegpu::targetinfo::subgroupSize && laneLayout[1] == 1;
+  return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
 }
 
 /// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
@@ -199,6 +200,11 @@ struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
   using OpRewritePattern<gpu::GPUFuncOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,
                                 PatternRewriter &rewriter) const override {
+    auto uArch = getUArch(xegpu::getChipStr(gpuFuncOp).value_or(""));
+    if (!uArch)
+      return rewriter.notifyMatchFailure(
+          gpuFuncOp, "Subgroup distribution requires a target attribute "
+                     "attached to set the warp size");
     // If the function only contains a single void return, skip.
     if (llvm::all_of(gpuFuncOp.getBody().getOps(), [](Operation &op) {
           return isa<gpu::ReturnOp>(op) && !op.getNumOperands();
@@ -230,7 +236,7 @@ struct MoveFuncBodyToWarpOp : public OpRewritePattern<gpu::GPUFuncOp> {
     ArrayRef<Type> gpuFuncResultType = gpuFuncOp.getFunctionType().getResults();
     auto warpOp = gpu::WarpExecuteOnLane0Op::create(
         rewriter, laneId.getLoc(), gpuFuncResultType, laneId,
-        xegpu::targetinfo::subgroupSize, newGpuFunc.getArguments(),
+        uArch->getSubgroupSize(), newGpuFunc.getArguments(),
         newGpuFunc.getArgumentTypes());
     Block &warpBodyBlock = warpOp.getBodyRegion().front();
     // Replace the ReturnOp of the original gpu function with a YieldOp.
@@ -495,14 +501,14 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
           warpOp, "warp result is not a xegpu::LoadNd op");
     auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
+    auto uArch = getUArch(xegpu::getChipStr(loadOp).value_or(""));
+    if (!uArch)
+      return rewriter.notifyMatchFailure(
+          loadOp, "xegpu::LoadNdOp requires a target attribute attached to "
+                  "determine the transpose requirement");
     // Chip information is required to decide if the layout requires transpose
     // effect.
-    auto chipStr = xegpu::getChipStr(loadOp);
-    if (!chipStr)
-      return rewriter.notifyMatchFailure(
-          loadOp,
-          "xegpu::LoadNdOp require chip information to determine transpose "
-          "requirement");
     // Expecting offsets to be present.
     SmallVector<OpFoldResult> offsets = loadOp.getMixedOffsets();
     if (offsets.empty())
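The target gate in requireTranspose bails out only when the chip is neither PVC nor BMG, hence the two negated case-insensitive comparisons combined with &&. Spelled out with LLVM's StringRef as a standalone sketch; the helper name here is made up:

#include "llvm/ADT/StringRef.h"
#include <cassert>

// True only for the targets where the transpose load path applies.
static bool isTransposeCapableChip(llvm::StringRef chip) {
  return chip.equals_insensitive("pvc") || chip.equals_insensitive("bmg");
}

int main() {
  assert(isTransposeCapableChip("PVC"));     // matching is case-insensitive
  assert(!isTransposeCapableChip("gfx942")); // any other chip returns false
}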
@@ -556,7 +562,7 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
     // Set the packed attribute if the layout requires it.
     newLoadOp.setPacked(requirePacked(layout));
     // Set the transpose attribute if the layout requires it.
-    if (requireTranspose(layout, chipStr.value()))
+    if (requireTranspose(layout, uArch))
       newLoadOp.setTranspose(
           DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
     Value distributedVal = newWarpOp.getResult(operandIdx);
diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp
index b31377e..0f1bf83 100644
--- a/mlir/lib/TableGen/Type.cpp
+++ b/mlir/lib/TableGen/Type.cpp
@@ -56,7 +56,7 @@ std::optional<StringRef> TypeConstraint::getBuilderCall() const {
         StringRef value = init->getValue();
         return value.empty() ? std::optional<StringRef>() : value;
       })
-      .Default([](auto *) { return std::nullopt; });
+      .Default(std::nullopt);
 }
 
 // Return the C++ type for this type (which may just be ::mlir::Type).
diff --git a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
index eeb8725..e3bcf27 100644
--- a/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugTranslation.cpp
@@ -390,7 +390,7 @@ llvm::DISubrange *DebugTranslation::translateImpl(DISubrangeAttr attr) {
             .Case<>([&](LLVM::DIGlobalVariableAttr global) {
               return translate(global);
             })
-            .Default([&](Attribute attr) { return nullptr; });
+            .Default(nullptr);
     return metadata;
   };
   return llvm::DISubrange::get(llvmCtx, getMetadataOrNull(attr.getCount()),
@@ -420,10 +420,10 @@ DebugTranslation::translateImpl(DIGenericSubrangeAttr attr) {
             .Case([&](LLVM::DILocalVariableAttr local) {
               return translate(local);
             })
-            .Case<>([&](LLVM::DIGlobalVariableAttr global) {
+            .Case([&](LLVM::DIGlobalVariableAttr global) {
               return translate(global);
             })
-            .Default([&](Attribute attr) { return nullptr; });
+            .Default(nullptr);
     return metadata;
   };
   return llvm::DIGenericSubrange::get(llvmCtx,
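The TableGen and DebugTranslation hunks are part of the same cleanup seen throughout this commit: llvm::TypeSwitch's Default accepts a plain value as well as a callable, so a lambda that merely returns a constant can be dropped. A minimal sketch of the pattern; the example function and case types are placeholders, not code from this commit:

#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/BuiltinTypes.h"
#include <optional>

// Before: .Default([](mlir::Type) { return std::nullopt; });
// After:  .Default(std::nullopt);  // the value overload used in this commit
static std::optional<int64_t> rankIfShaped(mlir::Type type) {
  return llvm::TypeSwitch<mlir::Type, std::optional<int64_t>>(type)
      .Case([](mlir::RankedTensorType t) { return t.getRank(); })
      .Case([](mlir::VectorType t) { return t.getRank(); })
      .Default(std::nullopt);
}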
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index f284540..8edec99 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4084,12 +4084,13 @@ static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
 ///
 /// Fortran
 ///     map(tofrom: array(2:5, 3:2))
-///   or
-/// C++
-///   map(tofrom: array[1:4][2:3])
+///
 /// We must calculate the initial pointer offset to pass across; this function
 /// performs this using bounds.
 ///
+/// TODO/WARNING: This only supports Fortran's column-major indexing at the
+/// moment; as noted below and in the comments in the function, we must extend
+/// it when we add a C++ frontend.
 /// NOTE: while specified in row-major order it currently needs to be
 /// flipped for Fortran's column order array allocation and access (as
 /// opposed to C++'s row-major, hence the backwards processing where order is
@@ -4125,46 +4126,28 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
     // with a pointer that's being treated like an array and we have the
     // underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base
     // address (pointer pointing to the actual data) so we must calculate the
-    // offset using a single index which the following two loops attempts to
-    // compute.
-
-    // Calculates the size offset we need to make per row e.g. first row or
-    // column only needs to be offset by one, but the next would have to be
-    // the previous row/column offset multiplied by the extent of current row.
+    // offset using a single index which the following loop attempts to
+    // compute using the standard column-major algorithm, e.g. for a 3D array:
     //
-    // For example ([1][10][100]):
+    // ((((c_idx * b_len) + b_idx) * a_len) + a_idx)
     //
-    //  - First row/column we move by 1 for each index increment
-    //  - Second row/column we move by 1 (first row/column) * 10 (extent/size of
-    //  current) for 10 for each index increment
-    //  - Third row/column we would move by 10 (second row/column) *
-    //  (extent/size of current) 100 for 1000 for each index increment
-    std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
-    for (size_t i = 1; i < bounds.size(); ++i) {
-      if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
-              bounds[i].getDefiningOp())) {
-        dimensionIndexSizeOffset.push_back(builder.CreateMul(
-            moduleTranslation.lookupValue(boundOp.getExtent()),
-            dimensionIndexSizeOffset[i - 1]));
-      }
-    }
-
-    // Now that we have calculated how much we move by per index, we must
-    // multiply each lower bound offset in indexes by the size offset we
-    // have calculated in the previous and accumulate the results to get
-    // our final resulting offset.
+    // Note that this is column-major rather than row-major at the moment, but
+    // having a way for the frontend to indicate which major format to use, or
+    // standardizing/canonicalizing the order of the bounds used to compute the
+    // offset, may be useful in the future when there are other frontends with
+    // different formats.
     for (int i = bounds.size() - 1; i >= 0; --i) {
       if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
               bounds[i].getDefiningOp())) {
-        if (idx.empty())
-          idx.emplace_back(builder.CreateMul(
-              moduleTranslation.lookupValue(boundOp.getLowerBound()),
-              dimensionIndexSizeOffset[i]));
+        if (i == ((int)bounds.size() - 1))
+          idx.emplace_back(
+              moduleTranslation.lookupValue(boundOp.getLowerBound()));
         else
           idx.back() = builder.CreateAdd(
-              idx.back(), builder.CreateMul(moduleTranslation.lookupValue(
-                                                boundOp.getLowerBound()),
-                                            dimensionIndexSizeOffset[i]));
+              builder.CreateMul(idx.back(), moduleTranslation.lookupValue(
+                                                boundOp.getExtent())),
+              moduleTranslation.lookupValue(boundOp.getLowerBound()));
       }
     }
   }
 }
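The rewritten loop folds the lower bounds into one linear index Horner-style, walking the bounds backwards so the fastest-varying Fortran dimension is applied last. A worked example with assumed extents and bounds (all values are illustrative):

#include <cassert>

int main() {
  // Assumed 3D array with extents a_len = 10 and b_len = 20 (c_len is not
  // needed) and lower bounds a_idx = 2, b_idx = 3, c_idx = 1.
  const long aLen = 10, bLen = 20;
  const long aIdx = 2, bIdx = 3, cIdx = 1;
  long idx = cIdx;         // the last bound seeds the index
  idx = idx * bLen + bIdx; // multiply by the extent, add the next lower bound
  idx = idx * aLen + aIdx; // same step for the fastest-varying dimension
  // Matches the column-major formula in the comment above: evaluates to 232.
  assert(idx == ((((cIdx * bLen) + bIdx) * aLen) + aIdx));
  return 0;
}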
diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp
index 08cac1f..5790a77 100644
--- a/mlir/lib/Transforms/ViewOpGraph.cpp
+++ b/mlir/lib/Transforms/ViewOpGraph.cpp
@@ -158,7 +158,8 @@ private:
   /// Emit a cluster (subgraph). The specified builder generates the body of the
   /// cluster. Return the anchor node of the cluster.
-  Node emitClusterStmt(function_ref<void()> builder, std::string label = "") {
+  Node emitClusterStmt(function_ref<void()> builder,
+                       const std::string &label = "") {
     int clusterId = ++counter;
     os << "subgraph cluster_" << clusterId << " {\n";
     os.indent();
@@ -269,7 +270,7 @@ private:
   }
 
   /// Emit a node statement.
-  Node emitNodeStmt(std::string label, StringRef shape = kShapeNode,
+  Node emitNodeStmt(const std::string &label, StringRef shape = kShapeNode,
                     StringRef background = "") {
     int nodeId = ++counter;
     AttributeMap attrs;
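The ViewOpGraph signature changes swap pass-by-value std::string parameters for const references, avoiding a heap-allocating copy per emitted statement. A trivial illustration of the difference; the function names below are made up:

#include <iostream>
#include <string>

// Hypothetical stand-ins for emitNodeStmt: the by-value version copies its
// argument on every call; the const-reference version binds without copying.
static void emitByValue(std::string label) { std::cout << label << "\n"; }
static void emitByRef(const std::string &label) { std::cout << label << "\n"; }

int main() {
  std::string label = "node_1";
  emitByValue(label); // makes a copy of "node_1"
  emitByRef(label);   // binds directly, no copy
  return 0;
}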
