Diffstat (limited to 'mlir/lib')
20 files changed, 630 insertions, 225 deletions
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index 82bdb84..74936e3 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -407,8 +407,8 @@ Parser::parseFloatFromIntegerLiteral(std::optional<APFloat> &result,                       "hexadecimal float constant out of range for type");    } -  APInt truncatedValue(typeSizeInBits, intValue.getNumWords(), -                       intValue.getRawData()); +  APInt truncatedValue(typeSizeInBits, +                       ArrayRef(intValue.getRawData(), intValue.getNumWords()));    result.emplace(semantics, truncatedValue);    return success();  } diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index ba57155..03ed4d5 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -240,8 +240,7 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {  struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {    using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; -  using Adaptor = -      typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor; +  using Adaptor = ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;    LogicalResult    matchAndRewrite(arith::SelectOp op, Adaptor adaptor, diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index 798d8b0..b75968e 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -137,8 +137,7 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {  /// op to llvm.br.  struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {    using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern; -  using Adaptor = -      typename ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor; +  using Adaptor = ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor;    LogicalResult    matchAndRewrite(cf::BranchOp op, Adaptor adaptor, @@ -163,8 +162,7 @@ struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {  /// branch op to llvm.cond_br.  struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {    using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern; -  using Adaptor = -      typename ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor; +  using Adaptor = ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor;    LogicalResult    matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor, @@ -204,7 +202,7 @@ struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {    using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;    LogicalResult -  matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor, +  matchAndRewrite(cf::SwitchOp op, cf::SwitchOp::Adaptor adaptor,                    ConversionPatternRewriter &rewriter) const override {      // Get or convert default block.      
FailureOr<Block *> convertedDefaultBlock = getConvertedBlock( diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index a2dfc12..a922338 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -68,7 +68,7 @@ struct ClampFOpConversion final        return LLVM::detail::handleMultidimensionalVectors(            op.getOperation(), adaptor.getOperands(), *getTypeConverter(),            [&](Type llvm1DVectorTy, ValueRange operands) -> Value { -            typename math::ClampFOp::Adaptor adaptor(operands); +            math::ClampFOp::Adaptor adaptor(operands);              return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy,                                            adaptor.getValue(), adaptor.getMin(),                                            adaptor.getMax()); diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 33e8f2e..de552ce 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -562,6 +562,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {      VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());      if (!valOrResVecTy)        valOrResVecTy = VectorType::get(1, data.getType()); +    if (valOrResVecTy.getShape().size() != 1) +      return rewriter.notifyMatchFailure(op, "Expected 1D data vector.");      int64_t elemBitWidth =          valOrResVecTy.getElementType().getIntOrFloatBitWidth(); diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp index 4d2d873..3d1a734 100644 --- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp @@ -66,9 +66,10 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,            .Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })            .Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })            .Case([](arith::MaxUIOp) { return arith::AtomicRMWKind::maxu; }) +          .Case([](arith::XOrIOp) { return arith::AtomicRMWKind::xori; }) +          .Case([](arith::MaxNumFOp) { return arith::AtomicRMWKind::maxnumf; }) +          .Case([](arith::MinNumFOp) { return arith::AtomicRMWKind::minnumf; })            .Default([](Operation *) -> std::optional<arith::AtomicRMWKind> { -            // TODO: AtomicRMW supports other kinds of reductions this is -            // currently not detecting, add those when the need arises.              return std::nullopt;            });    if (!maybeKind) diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 0992ce14..d478220 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -584,6 +584,10 @@ void ForOp::print(OpAsmPrinter &p) {  LogicalResult ForOp::verifyRegions() {    // Check that the body defines as single block argument for the induction    // variable. 
+  if (getBody()->getNumArguments() != 1) +    return emitOpError("expected body to have a single block argument for the " +                       "induction variable"); +    if (getInductionVar().getType() != getLowerBound().getType())      return emitOpError(          "expected induction variable to be same type as bounds and step"); diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index a5ffb9e..262d9b7 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -365,6 +365,59 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {    return success();  } +//===----------------------------------------------------------------------===// +// Stochastic Rounding Conversion Ops +//===----------------------------------------------------------------------===// + +LogicalResult ConvertF32x2ToF16x2Op::verify() { +  if (getRnd() != FPRoundingMode::RS) +    return emitOpError("Only RS rounding mode is supported for " +                       "conversions from f32x2 to f16x2."); +  return success(); +} + +LogicalResult ConvertF32x2ToBF16x2Op::verify() { +  if (getRnd() != FPRoundingMode::RS) +    return emitOpError("Only RS rounding mode is supported for " +                       "conversions from f32x2 to bf16x2."); +  return success(); +} + +LogicalResult ConvertF32x4ToF8x4Op::verify() { +  mlir::MLIRContext *ctx = getContext(); + +  if (!llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(getDstTy())) +    return emitOpError("Only ") +           << mlir::Float8E4M3FNType::get(ctx) << " and " +           << mlir::Float8E5M2Type::get(ctx) +           << " types are supported for conversions from f32x4 to f8x4."; + +  return success(); +} + +LogicalResult ConvertF32x4ToF6x4Op::verify() { +  mlir::MLIRContext *ctx = getContext(); + +  if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) +    return emitOpError("Only ") +           << mlir::Float6E2M3FNType::get(ctx) << " and " +           << mlir::Float6E3M2FNType::get(ctx) +           << " types are supported for conversions from f32x4 to f6x4."; + +  return success(); +} + +LogicalResult ConvertF32x4ToF4x4Op::verify() { +  mlir::MLIRContext *ctx = getContext(); + +  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy())) +    return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx) +                                << " type is supported for conversions from " +                                   "f32x4 to f4x4."; + +  return success(); +} +  LogicalResult BulkStoreOp::verify() {    if (getInitVal() != 0)      return emitOpError("only 0 is supported for initVal, got ") << getInitVal(); @@ -867,15 +920,40 @@ LogicalResult MmaOp::verify() {  }  LogicalResult ShflOp::verify() { -  if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid")) -    return success(); -  auto type = llvm::dyn_cast<LLVM::LLVMStructType>(getType()); -  auto elementType = (type && type.getBody().size() == 2) -                         ? 
llvm::dyn_cast<IntegerType>(type.getBody()[1]) -                         : nullptr; -  if (!elementType || elementType.getWidth() != 1) -    return emitError("expected return type to be a two-element struct with " -                     "i1 as the second element"); +  auto returnStructType = llvm::dyn_cast<LLVM::LLVMStructType>(getType()); + +  auto verifyTypeError = [&](Twine desc, Type expectedType, +                             Type actualType) -> LogicalResult { +    return emitOpError("expected " + desc + " to be of type ") +           << expectedType << " but got " << actualType << " instead"; +  }; + +  if (returnStructType) { +    if (!getReturnValueAndIsValid()) +      return emitOpError("\"return_value_and_is_valid\" attribute must be " +                         "specified when the return type is a struct type"); + +    if (returnStructType.getBody().size() != 2) +      return emitOpError("expected return type to be a two-element struct"); + +    llvm::ArrayRef<Type> returnStruct = returnStructType.getBody(); +    auto resultType = returnStruct[0]; +    if (resultType != getVal().getType()) +      return verifyTypeError("first element in the returned struct", +                             getVal().getType(), resultType); + +    auto predicateType = returnStruct[1]; +    if (!predicateType.isInteger(1)) +      return verifyTypeError("second element in the returned struct", +                             mlir::IntegerType::get(getContext(), 1), +                             predicateType); +  } else { +    if (getReturnValueAndIsValid()) +      return emitOpError("expected return type to be a two-element struct"); + +    if (getType() != getVal().getType()) +      return verifyTypeError("return type", getVal().getType(), getType()); +  }    return success();  } @@ -1577,6 +1655,43 @@ LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {    return success();  } +LogicalResult NVVM::ReduxOp::verify() { +  mlir::Type reduxType = getType(); + +  if (!reduxType.isF32()) { +    if (getAbs()) +      return emitOpError("abs attribute is supported only for f32 type"); +    if (getNan()) +      return emitOpError("nan attribute is supported only for f32 type"); +  } + +  NVVM::ReduxKind kind = getKind(); +  switch (kind) { +  case NVVM::ReduxKind::ADD: +  case NVVM::ReduxKind::AND: +  case NVVM::ReduxKind::OR: +  case NVVM::ReduxKind::XOR: +  case NVVM::ReduxKind::MAX: +  case NVVM::ReduxKind::MIN: +  case NVVM::ReduxKind::UMAX: +  case NVVM::ReduxKind::UMIN: +    if (!reduxType.isInteger(32)) +      return emitOpError("'") +             << stringifyEnum(kind) << "' redux kind unsupported with " +             << reduxType << " type. Only supported type is 'i32'."; +    break; +  case NVVM::ReduxKind::FMIN: +  case NVVM::ReduxKind::FMAX: +    if (!reduxType.isF32()) +      return emitOpError("'") +             << stringifyEnum(kind) << "' redux kind unsupported with " +             << reduxType << " type. Only supported type is 'f32'."; +    break; +  } + +  return success(); +} +  /// Packs the given `field` into the `result`.  /// The `result` is 64-bits and each `field` can be 32-bits or narrower.  
static llvm::Value * @@ -2469,6 +2584,85 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,      return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \    }() +llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() { +  bool hasRelu = getRelu(); +  bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE); + +  if (hasRelu && hasSatFinite) +    return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite; +  if (hasRelu) +    return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu; +  if (hasSatFinite) +    return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite; +  return llvm::Intrinsic::nvvm_ff2f16x2_rs; +} + +llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() { +  bool hasRelu = getRelu(); +  bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE); + +  if (hasRelu && hasSatFinite) +    return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite; +  if (hasRelu) +    return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu; +  if (hasSatFinite) +    return llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite; +  return llvm::Intrinsic::nvvm_ff2bf16x2_rs; +} + +llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() { +  mlir::Type dstTy = getDstTy(); +  bool hasRelu = getRelu(); + +  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy) +      .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) { +        return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite +                       : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite; +      }) +      .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) { +        return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite +                       : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite; +      }) +      .Default([](mlir::Type) { +        llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op"); +        return llvm::Intrinsic::not_intrinsic; +      }); +} + +llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() { +  mlir::Type dstTy = getDstTy(); +  bool hasRelu = getRelu(); + +  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy) +      .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) { +        return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite +                       : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite; +      }) +      .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) { +        return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite +                       : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite; +      }) +      .Default([](mlir::Type) { +        llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op"); +        return llvm::Intrinsic::not_intrinsic; +      }); +} + +llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() { +  mlir::Type dstTy = getDstTy(); +  bool hasRelu = getRelu(); + +  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy) +      .Case<mlir::Float4E2M1FNType>([&](mlir::Float4E2M1FNType) { +        return hasRelu ? 
llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite +                       : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite; +      }) +      .Default([](mlir::Type) { +        llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op"); +        return llvm::Intrinsic::not_intrinsic; +      }); +} +  llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {    auto curOp = cast<NVVM::Tcgen05CpOp>(op);    bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2; @@ -2508,6 +2702,9 @@ LogicalResult Tcgen05LdOp::verify() {    if (getShape() == NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && !getOffset())      result = emitError("shape 16x32bx2 requires offset argument"); +  if (getShape() != NVVM::Tcgen05LdStShape::SHAPE_16X32BX2 && getOffset()) +    result = emitError("offset argument is only supported for shape 16x32bx2"); +    auto resTy = getRes().getType();    unsigned resLen = isa<VectorType>(resTy)                          ? llvm::cast<VectorType>(resTy).getNumElements() diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 1c21a2f..e271ac5 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2568,6 +2568,11 @@ computeCollapsedLayoutMap(MemRefType srcType,      auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();      auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);      for (int64_t idx : llvm::reverse(trailingReassocs)) { +      // Dimensions of size 1 should be skipped, because their strides are +      // meaningless and could have any arbitrary value. +      if (srcShape[idx - 1] == 1) +        continue; +        stride = stride * SaturatedInteger::wrap(srcShape[idx]);        // Both source and result stride must have the same static value. In that @@ -2582,11 +2587,6 @@ computeCollapsedLayoutMap(MemRefType srcType,        if (strict && (stride.saturated || srcStride.saturated))          return failure(); -      // Dimensions of size 1 should be skipped, because their strides are -      // meaningless and could have any arbitrary value. -      if (srcShape[idx - 1] == 1) -        continue; -        if (!stride.saturated && !srcStride.saturated && stride != srcStride)          return failure();      } diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 35eba72..b2f1d84 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -3068,8 +3068,12 @@ LogicalResult acc::LoopOp::verify() {    if (getRegion().empty())      return emitError("expected non-empty body."); -  // When it is container-like - it is expected to hold a loop-like operation. -  if (isContainerLike()) { +  if (getUnstructured()) { +    if (!isContainerLike()) +      return emitError( +          "unstructured acc.loop must not have induction variables"); +  } else if (isContainerLike()) { +    // When it is container-like - it is expected to hold a loop-like operation.      // Obtain the maximum collapse count - we use this to check that there      // are enough loops contained.      
uint64_t collapseCount = getCollapseValue().value_or(1); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 2946b53..881e256 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -2565,6 +2565,39 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {  struct ConditionPropagation : public OpRewritePattern<IfOp> {    using OpRewritePattern<IfOp>::OpRewritePattern; +  /// Kind of parent region in the ancestor cache. +  enum class Parent { Then, Else, None }; + +  /// Returns the kind of region ("then", "else", or "none") of the +  /// IfOp that the given region is transitively nested in. Updates +  /// the cache accordingly. +  static Parent getParentType(Region *toCheck, IfOp op, +                              DenseMap<Region *, Parent> &cache, +                              Region *endRegion) { +    SmallVector<Region *> seen; +    while (toCheck != endRegion) { +      auto found = cache.find(toCheck); +      if (found != cache.end()) +        return found->second; +      seen.push_back(toCheck); +      if (&op.getThenRegion() == toCheck) { +        for (Region *region : seen) +          cache[region] = Parent::Then; +        return Parent::Then; +      } +      if (&op.getElseRegion() == toCheck) { +        for (Region *region : seen) +          cache[region] = Parent::Else; +        return Parent::Else; +      } +      toCheck = toCheck->getParentRegion(); +    } + +    for (Region *region : seen) +      cache[region] = Parent::None; +    return Parent::None; +  } +    LogicalResult matchAndRewrite(IfOp op,                                  PatternRewriter &rewriter) const override {      // Early exit if the condition is constant since replacing a constant @@ -2580,9 +2613,12 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {      Value constantTrue = nullptr;      Value constantFalse = nullptr; +    DenseMap<Region *, Parent> cache;      for (OpOperand &use :           llvm::make_early_inc_range(op.getCondition().getUses())) { -      if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) { +      switch (getParentType(use.getOwner()->getParentRegion(), op, cache, +                            op.getCondition().getParentRegion())) { +      case Parent::Then: {          changed = true;          if (!constantTrue) @@ -2591,8 +2627,9 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {          rewriter.modifyOpInPlace(use.getOwner(),                                   [&]() { use.set(constantTrue); }); -      } else if (op.getElseRegion().isAncestor( -                     use.getOwner()->getParentRegion())) { +        break; +      } +      case Parent::Else: {          changed = true;          if (!constantFalse) @@ -2601,6 +2638,10 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {          rewriter.modifyOpInPlace(use.getOwner(),                                   [&]() { use.set(constantFalse); }); +        break; +      } +      case Parent::None: +        break;        }      } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index 7a26cd3..1fbcf5f 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -1050,7 +1050,7 @@ public:  /// Sparse codegen rule for position accesses.  
class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {  public: -  using OpAdaptor = typename ToPositionsOp::Adaptor; +  using OpAdaptor = ToPositionsOp::Adaptor;    using OpConversionPattern<ToPositionsOp>::OpConversionPattern;    LogicalResult    matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor, @@ -1073,7 +1073,7 @@ public:  class SparseToCoordinatesConverter      : public OpConversionPattern<ToCoordinatesOp> {  public: -  using OpAdaptor = typename ToCoordinatesOp::Adaptor; +  using OpAdaptor = ToCoordinatesOp::Adaptor;    using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;    LogicalResult    matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor, @@ -1099,7 +1099,7 @@ public:  class SparseToCoordinatesBufferConverter      : public OpConversionPattern<ToCoordinatesBufferOp> {  public: -  using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor; +  using OpAdaptor = ToCoordinatesBufferOp::Adaptor;    using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;    LogicalResult    matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor, @@ -1121,7 +1121,7 @@ public:  /// Sparse codegen rule for value accesses.  class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {  public: -  using OpAdaptor = typename ToValuesOp::Adaptor; +  using OpAdaptor = ToValuesOp::Adaptor;    using OpConversionPattern<ToValuesOp>::OpConversionPattern;    LogicalResult    matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor, diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index ae3423c..daef0ba 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -717,7 +717,15 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,    case arith::AtomicRMWKind::ori:      return vector::ReductionOp::create(builder, vector.getLoc(),                                         CombiningKind::OR, vector); -  // TODO: Add remaining reduction operations. +  case arith::AtomicRMWKind::minnumf: +    return vector::ReductionOp::create(builder, vector.getLoc(), +                                       CombiningKind::MINNUMF, vector); +  case arith::AtomicRMWKind::maxnumf: +    return vector::ReductionOp::create(builder, vector.getLoc(), +                                       CombiningKind::MAXNUMF, vector); +  case arith::AtomicRMWKind::xori: +    return vector::ReductionOp::create(builder, vector.getLoc(), +                                       CombiningKind::XOR, vector);    default:      (void)emitOptionalError(loc, "Reduction operation type not supported");      break; diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 83406c8..397107b 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -37,55 +37,61 @@ void XeGPUDialect::initialize() {        >();  } -/// Generates instructions to compute offsets for a subgroup identified by -/// its multidimensional indices (sgId), using the specified subgroup layout -/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data -/// dimensions (sizePerWg). +// A `srcShape` consists of N distribution units, each being `subShapesLayout` x +// `subShape`. A `delinearizedId` is used to identify a particular `subShape` +// within each distribution unit. +// Example: +// WG data is 128x256. 
SG data is 16x32, in 4x2 layout, this gives a +// distribution unit of shape 64x64, we have 2x4 such distribution units. +// `delinearizedId` is used to identify a 16x32 of a subgroup in each +// distribution unit.  static SmallVector<SmallVector<Value>> -genOffsetsComputingInsts(OpBuilder &builder, Location loc, -                         SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout, -                         ArrayRef<int64_t> sizePerSg, -                         ArrayRef<int64_t> sizePerWg) { - -  SmallVector<SmallVector<Value>> offsets; +genCoordinates(OpBuilder &builder, Location loc, +               SmallVector<Value> delinearizedId, +               ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape, +               ArrayRef<int64_t> srcShape) { +  SmallVector<SmallVector<Value>> coordinates; + +  // A distribution unit must be less than or equal to `srcShape` +  SmallVector<int64_t> distUnitShape = llvm::map_to_vector( +      llvm::zip_equal(srcShape, +                      computeElementwiseMul(subShapesLayout, subShape)), +      [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); }); -  // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i] -  SmallVector<Value> localOffsets = llvm::map_to_vector( -      llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value { +  // Get the offset of `subShape` within a distribution unit. +  SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector( +      llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {          return builder.createOrFold<index::MulOp>(              loc, std::get<0>(t),              builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));        }); -  // distUnit[i] is the minimum value between sizePerWg[i] and -  // sgLayout[i] * sizePerSg[i] -  SmallVector<int64_t> distUnit = llvm::map_to_vector( -      llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)), -      [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); }); - +  // For each dist unit    for (SmallVector<int64_t> unitOffs : -       StaticTileOffsetRange(sizePerWg, distUnit)) { +       StaticTileOffsetRange(srcShape, distUnitShape)) { +    // Get dist unit offset within `srcShape`.      SmallVector<Value> base =          llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {            return arith::ConstantIndexOp::create(builder, loc, d);          }); - -    SmallVector<Value> adds = llvm::map_to_vector( -        llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value { -          return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t), -                                                     std::get<1>(t)); -        }); - +    // Calculate `subShape` offset within `srcShape`. +    SmallVector<Value> adds = +        llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset), +                            [&](const auto &t) -> Value { +                              return builder.createOrFold<arith::AddIOp>( +                                  loc, std::get<0>(t), std::get<1>(t)); +                            }); +    // Do not go beyond `srcShape` bounds.      
SmallVector<Value> mods = llvm::map_to_vector( -        llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value { +        llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {            return builder.createOrFold<index::RemUOp>(                loc, std::get<0>(t),                arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));          }); -    offsets.push_back(mods); +    coordinates.push_back(mods);    } -  return offsets; +  return coordinates;  }  // Checks if the given shape can be evenly distributed based on the layout @@ -272,12 +278,7 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,  }  FailureOr<SmallVector<Value>> -LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, -                                  Value linearId) { -  // delinearizeSubgroupId is only available for -  // workgroup-level layout attribute -  if (!isForWorkgroup()) -    return failure(); +LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {    // TODO: handle order attribute    auto hasDefaultOrder = [&]() { @@ -287,41 +288,52 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,    };    if (!hasDefaultOrder())      return mlir::emitError(loc, "order attribute is currently not supported."); - -  auto dims = -      llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value { -        return builder.createOrFold<arith::ConstantIndexOp>(loc, d); -      }); +  SmallVector<int64_t> layout; +  if (isForWorkgroup()) { +    layout = getEffectiveSgLayoutAsInt(); +  } else if (isForSubgroup()) { +    layout = getEffectiveLaneLayoutAsInt(); +  } else { +    return failure(); +  } +  auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value { +    return builder.createOrFold<arith::ConstantIndexOp>(loc, d); +  });    return affine::delinearizeIndex(builder, loc, linearId, dims);  } -/// Implements DistributeLayoutAttr::getOffsets to generate +/// Implements DistributeLayoutAttr::computeDistributedCoords to generate  /// instructions for computing multi-dimensional offsets when distributed by  /// LayoutAttr.  
FailureOr<SmallVector<SmallVector<Value>>> -LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, -                       ArrayRef<int64_t> shape) { -  if (!isForWorkgroup()) +LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc, +                                     Value linearId, ArrayRef<int64_t> shape) { +  SmallVector<int64_t> layout; +  SmallVector<int64_t> subShape; +  if (isForWorkgroup()) { +    layout = getEffectiveSgLayoutAsInt(); +    subShape = getEffectiveSgDataAsInt(); +  } else if (isForSubgroup()) { +    layout = getEffectiveLaneLayoutAsInt(); +    subShape = getEffectiveLaneDataAsInt(); +  } else {      return failure(); - -  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt(); -  SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt(); -  if (sgShape.empty()) { -    if (auto derivedShape = computeShapeRatio(shape, sgLayout)) -      sgShape = derivedShape.value(); +  } +  if (subShape.empty()) { +    if (auto derivedShape = computeShapeRatio(shape, layout)) +      subShape = derivedShape.value();      else        return failure();    }    // delinearize Ids -  auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); +  auto maybeIds = delinearizeId(builder, loc, linearId);    if (failed(maybeIds))      return failure(); -  SmallVector<Value> sgIds = *maybeIds; +  SmallVector<Value> ids = *maybeIds; -  return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, -                                  shape); +  return genCoordinates(builder, loc, ids, layout, subShape, shape);  }  //===----------------------------------------------------------------------===// @@ -375,34 +387,43 @@ SliceAttr SliceAttr::flatten() const {  }  FailureOr<SmallVector<Value>> -SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, -                                 Value linearId) { +SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {    SliceAttr attr = flatten();    auto parent = dyn_cast<LayoutAttr>(attr.getParent()); -  return parent.delinearizeSubgroupId(builder, loc, linearId); +  return parent.delinearizeId(builder, loc, linearId);  } -/// Implements DistributeLayoutAttr::getOffsets to generate -/// instructions for computing multi-dimensional offsets when distributed by -/// SliceAttr. +// Implements DistributeLayoutAttr::computeDistributedCoords to generate +// instructions for computing multi-dimensional offsets when distributed by +// LayoutAttr.  
FailureOr<SmallVector<SmallVector<Value>>> -SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, -                      ArrayRef<int64_t> shape) { +SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc, +                                    Value linearId, ArrayRef<int64_t> shape) {    assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");    if (!isForWorkgroup())      return failure(); -  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt(); -  SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt(); -  if (sgShape.empty()) { -    if (auto derivedShape = computeShapeRatio(shape, sgLayout)) -      sgShape = derivedShape.value(); +  SmallVector<int64_t> layout; +  SmallVector<int64_t> subShape; +  if (isForWorkgroup()) { +    layout = getEffectiveSgLayoutAsInt(); +    subShape = getEffectiveSgDataAsInt(); +  } else if (isForSubgroup()) { +    layout = getEffectiveLaneLayoutAsInt(); +    subShape = getEffectiveLaneDataAsInt(); +  } else { +    return failure(); +  } + +  if (subShape.empty()) { +    if (auto derivedShape = computeShapeRatio(shape, layout)) +      subShape = derivedShape.value();      else        return failure();    }    // delinearize Ids -  auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); +  auto maybeIds = delinearizeId(builder, loc, linearId);    if (failed(maybeIds))      return failure(); @@ -412,8 +433,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,    SmallVector<Value> sgIds =        XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims); -  return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, -                                  shape); +  return genCoordinates(builder, loc, sgIds, layout, subShape, shape);  }  bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) { diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index abd12e2..7b6c4b6 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -175,13 +175,13 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,  LogicalResult  IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, -                      UnitAttr subgroup_block_io, +                      UnitAttr subgroup_block_io, DistributeLayoutAttr layout,                        function_ref<InFlightDiagnostic()> emitError) {    if (!dataTy) {      if (subgroup_block_io)        return emitError() << "subgroup_block_io " -                            "are only allowed when result is a 1D VectorType."; +                            "are only allowed when result is a VectorType.";      else        return success();    } @@ -192,15 +192,37 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,    ArrayRef<int64_t> dataShape = dataTy.getShape();    ArrayRef<int64_t> mdescShape = mdescTy.getShape(); +  SmallVector<int64_t> blockShape = mdescTy.getBlockShape(); +  ArrayAttr strideAttr = mdescTy.getStrideAttr(); +  SmallVector<int64_t> strides; +  for (Attribute attr : strideAttr.getValue()) { +    strides.push_back(cast<IntegerAttr>(attr).getInt()); +  } +  if (subgroup_block_io && layout) { +    auto laneData = layout.getEffectiveLaneDataAsInt(); +    auto laneLayout = layout.getEffectiveLaneLayoutAsInt(); +    if (!laneData.empty()) { +      bool isLaneDataContiguous = +          std::all_of(laneData.begin(), std::prev(laneData.end()), +                      [](int x) { return x == 1; }); +      if (!isLaneDataContiguous) 
+        return emitError() << "With subgroup_block_io, accessed data must be " +                              "contiguous and coalesced."; +      for (size_t i = 0; i < laneData.size(); ++i) { +        if (laneLayout[i] != blockShape[i]) +          return emitError() << "With subgroup_block_io, the block shape must " +                                "match the lane layout."; +        if (laneLayout[i] != 1 && strides[i] != 1) +          return emitError() << "With subgroup_block_io, the distributed " +                                "dimensions must be contiguous."; +      } +    } +  }    if (dataShape.size() == 2) { -    if (subgroup_block_io) -      return emitError() << "subgroup_block_io " -                            "are only allowed when result is a 1D VectorType.";      if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),                       [](auto p) { return std::get<0>(p) > std::get<1>(p); }))        return emitError() << "data shape must not exceed mem_desc shape.";    } else { -    SmallVector<int64_t> blockShape = mdescTy.getBlockShape();      // if the subgroup_block_io attribute is set,  mdescTy must have block      // attribute      if (subgroup_block_io && !blockShape.size()) @@ -1105,7 +1127,7 @@ LogicalResult LoadMatrixOp::verify() {    MemDescType mdescTy = getMemDesc().getType();    return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io, -                               [&]() { return emitError(); }); +                               getLayoutAttr(), [&]() { return emitError(); });  }  //===----------------------------------------------------------------------===// @@ -1129,7 +1151,7 @@ LogicalResult StoreMatrixOp::verify() {    UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();    MemDescType mdescTy = getMemDesc().getType();    return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io, -                               [&]() { return emitError(); }); +                               getLayoutAttr(), [&]() { return emitError(); });  }  namespace mlir { diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 5a3b27e..bbd7733 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -7,6 +7,7 @@  //===----------------------------------------------------------------------===//  #include "mlir/Dialect/GPU/IR/GPUDialect.h"  #include "mlir/Dialect/GPU/Utils/DistributionUtils.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h"  #include "mlir/Dialect/MemRef/IR/MemRef.h"  #include "mlir/Dialect/Vector/IR/VectorOps.h"  #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" @@ -912,6 +913,186 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {    }  }; +static SmallVector<Value> computeDistributedCoordinatesForMatrixOp( +    PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout, +    Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) { +  SmallVector<Value> newCoods; +  auto maybeCoords = +      layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape); +  if (failed(maybeCoords)) +    return {}; +  assert(maybeCoords.value().size() == 1 && +         "Expected one set of distributed offsets"); +  SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned( +      rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]), +      getAsOpFoldResult(origOffsets)); +  newCoods = 
llvm::to_vector(llvm::map_range( +      ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); })); +  return newCoods; +} + +/// Pattern for distributing xegpu::LoadMatrixOp. +struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern { +  using gpu::WarpDistributionPattern::WarpDistributionPattern; +  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, +                                PatternRewriter &rewriter) const override { +    gpu::YieldOp yield = warpOp.getTerminator(); +    Operation *lastNode = yield->getPrevNode(); +    auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode); +    if (!matrixOp) +      return failure(); + +    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) { +      return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op; +    }); +    if (!producedByLastLoad) +      return rewriter.notifyMatchFailure( +          warpOp, "The last op is not xegpu::LoadMatrixOp"); +    const int operandIdx = producedByLastLoad->getOperandNumber(); + +    VectorType sgPayloadTy = +        dyn_cast<VectorType>(matrixOp.getResult().getType()); +    VectorType warpResultTy = +        cast<VectorType>(warpOp.getResult(operandIdx).getType()); +    if (!sgPayloadTy) +      return rewriter.notifyMatchFailure( +          matrixOp, "the matrix op payload must be a vector type"); + +    auto loc = matrixOp.getLoc(); +    auto offsets = matrixOp.getMixedOffsets(); +    if (offsets.empty()) +      return rewriter.notifyMatchFailure(matrixOp, +                                         "the load op must have offsets"); +    SmallVector<Value> offsetsAsValues = +        vector::getAsValues(rewriter, matrixOp.getLoc(), offsets); + +    auto layout = matrixOp.getLayoutAttr(); +    if (!layout) +      return rewriter.notifyMatchFailure( +          matrixOp, "the matrix operation lacks layout attribute"); + +    FailureOr<VectorType> distPayloadByWarpOpOrFailure = +        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy); +    if (failed(distPayloadByWarpOpOrFailure)) +      return rewriter.notifyMatchFailure( +          matrixOp, "Failed to distribute matrix op payload based on layout."); + +    SmallVector<Value> operands = {matrixOp.getMemDesc()}; +    const unsigned offsetsStartIdx = operands.size(); +    operands.append(offsetsAsValues); + +    SmallVector<Type> operandTypes = llvm::to_vector( +        llvm::map_range(operands, [](Value v) { return v.getType(); })); + +    SmallVector<size_t> newRetIndices; +    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( +        rewriter, warpOp, operands, operandTypes, newRetIndices); +    SmallVector<Value> newOperands = llvm::map_to_vector( +        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); }); + +    SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()}; +    std::fill(newConstOffsets.begin(), newConstOffsets.end(), +              ShapedType::kDynamic); +    DenseI64ArrayAttr newConstOffsetsAttr = +        rewriter.getDenseI64ArrayAttr(newConstOffsets); +    ValueRange currentOffsets = +        ValueRange(newOperands).drop_front(offsetsStartIdx); + +    SmallVector<Value> newCoords = currentOffsets; +    rewriter.setInsertionPointAfter(newWarpOp); + +    if (!matrixOp.getSubgroupBlockIoAttr()) { +      newCoords = computeDistributedCoordinatesForMatrixOp( +          rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(), +          currentOffsets); +    } +    xegpu::LoadMatrixOp newOp = 
xegpu::LoadMatrixOp::create( +        rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure, +        newOperands[0], ValueRange(newCoords), newConstOffsetsAttr, +        matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{}); +    // Resolve the output type and replace all uses. +    rewriter.replaceAllUsesWith( +        newWarpOp.getResult(operandIdx), +        resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter)); +    return success(); +  } +}; + +/// Pattern for distributing xegpu::StoreMatrixOp. +struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern { +  using gpu::WarpDistributionPattern::WarpDistributionPattern; +  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, +                                PatternRewriter &rewriter) const override { +    gpu::YieldOp yield = warpOp.getTerminator(); +    Operation *lastNode = yield->getPrevNode(); +    auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode); +    if (!matrixOp) +      return failure(); + +    VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType()); +    if (!sgPayloadTy) +      return rewriter.notifyMatchFailure( +          matrixOp, "the matrix op payload must be a vector type"); + +    auto loc = matrixOp.getLoc(); +    auto offsets = matrixOp.getMixedOffsets(); +    if (offsets.empty()) +      return rewriter.notifyMatchFailure(matrixOp, +                                         "the store op must have offsets"); +    SmallVector<Value> offsetsAsValues = +        vector::getAsValues(rewriter, matrixOp.getLoc(), offsets); + +    auto layout = matrixOp.getLayoutAttr(); +    if (!layout) +      return rewriter.notifyMatchFailure( +          matrixOp, "the matrix operation lacks layout attribute"); + +    FailureOr<VectorType> distPayloadByWarpOpOrFailure = +        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy); +    if (failed(distPayloadByWarpOpOrFailure)) +      return rewriter.notifyMatchFailure( +          matrixOp, "Failed to distribute matrix op payload based on layout."); + +    SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()}; +    const unsigned offsetsStartIdx = operands.size(); +    operands.append(offsetsAsValues); + +    SmallVector<Type> operandTypes = llvm::to_vector( +        llvm::map_range(operands, [](Value v) { return v.getType(); })); +    operandTypes[0] = *distPayloadByWarpOpOrFailure; + +    SmallVector<size_t> newRetIndices; +    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( +        rewriter, warpOp, operands, operandTypes, newRetIndices); +    SmallVector<Value> newOperands = llvm::map_to_vector( +        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); }); + +    SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()}; +    std::fill(newConstOffsets.begin(), newConstOffsets.end(), +              ShapedType::kDynamic); +    DenseI64ArrayAttr newConstOffsetsAttr = +        rewriter.getDenseI64ArrayAttr(newConstOffsets); +    ValueRange currentOffsets = +        ValueRange(newOperands).drop_front(offsetsStartIdx); + +    SmallVector<Value> newCoords = currentOffsets; +    rewriter.setInsertionPointAfter(newWarpOp); + +    if (!matrixOp.getSubgroupBlockIoAttr()) { +      newCoords = computeDistributedCoordinatesForMatrixOp( +          rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(), +          currentOffsets); +    } + +    xegpu::StoreMatrixOp::create( +        rewriter, loc, TypeRange{}, 
newOperands[0], newOperands[1], +        ValueRange(newCoords), newConstOffsetsAttr, +        matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{}); +    rewriter.eraseOp(matrixOp); +    return success(); +  } +}; +  /// Distribute a scattered load op. The logic and requirements are the same as  /// for the scattered store distribution. The warpOp's payload vector is  /// expected to be distributed by the load's result consumer. @@ -1443,7 +1624,8 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(                 LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,                 GpuBarrierDistribution, VectorMultiReductionDistribution,                 LoadDistribution, StoreDistribution, VectorTransposeDistribution, -               VectorBitcastDistribution, +               VectorBitcastDistribution, LoadMatrixDistribution, +               StoreMatrixDistribution,                 MemrefExtractAlignedPointerAsIndexDistribution>(        patterns.getContext(),        /*pattern benefit=*/regularPatternBenefit); @@ -1468,6 +1650,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {        // Layouts are needed for vector type only.        if (!isa<VectorType>(operand.get().getType()))          continue; +      if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op)) +        continue;        auto layout = xegpu::getDistributeLayoutAttr(operand.get());        if (!layout) { diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 9fc5ad9..79eea55 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -114,7 +114,8 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,    // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory    // descriptors to be accessed, based on the layout information.    
ArrayRef<int64_t> wgShape = op.getDataShape(); -  auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); +  auto maybeDescOffsets = +      layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);    if (failed(maybeDescOffsets))      return failure(); @@ -830,8 +831,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {        // Get subgroup id        Value sgId =            gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); - -      auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); +      auto sgOffsets = +          layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);        if (failed(sgOffsets))          return failure(); @@ -1052,7 +1053,8 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {      Value sgId =          gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); -    auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); +    auto sgOffsets = +        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);      if (failed(sgOffsets))        return failure(); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index ce421f4..8212d6d 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -463,28 +463,26 @@ void Operation::updateOrderIfNecessary() {  //===----------------------------------------------------------------------===//  auto llvm::ilist_detail::SpecificNodeAccess< -    typename llvm::ilist_detail::compute_node_options< -        ::mlir::Operation>::type>::getNodePtr(pointer n) -> node_type * { +    llvm::ilist_detail::compute_node_options<::mlir::Operation>::type>:: +    getNodePtr(pointer n) -> node_type * {    return NodeAccess::getNodePtr<OptionsT>(n);  }  auto llvm::ilist_detail::SpecificNodeAccess< -    typename llvm::ilist_detail::compute_node_options< -        ::mlir::Operation>::type>::getNodePtr(const_pointer n) -    -> const node_type * { +    llvm::ilist_detail::compute_node_options<::mlir::Operation>::type>:: +    getNodePtr(const_pointer n) -> const node_type * {    return NodeAccess::getNodePtr<OptionsT>(n);  }  auto llvm::ilist_detail::SpecificNodeAccess< -    typename llvm::ilist_detail::compute_node_options< -        ::mlir::Operation>::type>::getValuePtr(node_type *n) -> pointer { +    llvm::ilist_detail::compute_node_options<::mlir::Operation>::type>:: +    getValuePtr(node_type *n) -> pointer {    return NodeAccess::getValuePtr<OptionsT>(n);  }  auto llvm::ilist_detail::SpecificNodeAccess< -    typename llvm::ilist_detail::compute_node_options< -        ::mlir::Operation>::type>::getValuePtr(const node_type *n) -    -> const_pointer { +    llvm::ilist_detail::compute_node_options<::mlir::Operation>::type>:: +    getValuePtr(const node_type *n) -> const_pointer {    return NodeAccess::getValuePtr<OptionsT>(n);  } diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index 3d86b09..0964e1b 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -36,9 +36,6 @@ using mlir::LLVM::detail::createIntrinsicCall;  static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,                                                 NVVM::ReduxKind kind,                                                 bool hasAbs, bool hasNaN) { -  if (!(resultType->isIntegerTy(32) || resultType->isFloatTy())) -    
llvm_unreachable("unsupported data type for redux"); -    switch (kind) {    case NVVM::ReduxKind::ADD:      return llvm::Intrinsic::nvvm_redux_sync_add; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 3a23bbf..2fe0697 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1105,10 +1105,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {    /// A set of operations that were modified by the current pattern.    SetVector<Operation *> patternModifiedOps; -  /// A set of blocks that were inserted (newly-created blocks or moved blocks) -  /// by the current pattern. -  SetVector<Block *> patternInsertedBlocks; -    /// A list of unresolved materializations that were created by the current    /// pattern.    DenseSet<UnrealizedConversionCastOp> patternMaterializations; @@ -2046,8 +2042,6 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(    if (!config.allowPatternRollback && config.listener)      config.listener->notifyBlockInserted(block, previous, previousIt); -  patternInsertedBlocks.insert(block); -    if (wasDetached) {      // If the block was detached, it is most likely a newly created block.      if (config.allowPatternRollback) { @@ -2399,17 +2393,12 @@ private:    bool canApplyPattern(Operation *op, const Pattern &pattern);    /// Legalize the resultant IR after successfully applying the given pattern. -  LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern, -                                      const RewriterState &curState, -                                      const SetVector<Operation *> &newOps, -                                      const SetVector<Operation *> &modifiedOps, -                                      const SetVector<Block *> &insertedBlocks); - -  /// Legalizes the actions registered during the execution of a pattern.    LogicalResult -  legalizePatternBlockRewrites(Operation *op, -                               const SetVector<Block *> &insertedBlocks, -                               const SetVector<Operation *> &newOps); +  legalizePatternResult(Operation *op, const Pattern &pattern, +                        const RewriterState &curState, +                        const SetVector<Operation *> &newOps, +                        const SetVector<Operation *> &modifiedOps); +    LogicalResult    legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);    LogicalResult @@ -2608,7 +2597,6 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {    auto cleanup = llvm::make_scope_exit([&]() {      rewriterImpl.patternNewOps.clear();      rewriterImpl.patternModifiedOps.clear(); -    rewriterImpl.patternInsertedBlocks.clear();    });    // Upon failure, undo all changes made by the folder. 
@@ -2662,24 +2650,16 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {  static void  reportNewIrLegalizationFatalError(const Pattern &pattern,                                    const SetVector<Operation *> &newOps, -                                  const SetVector<Operation *> &modifiedOps, -                                  const SetVector<Block *> &insertedBlocks) { +                                  const SetVector<Operation *> &modifiedOps) {    auto newOpNames = llvm::map_range(        newOps, [](Operation *op) { return op->getName().getStringRef(); });    auto modifiedOpNames = llvm::map_range(        modifiedOps, [](Operation *op) { return op->getName().getStringRef(); }); -  StringRef detachedBlockStr = "(detached block)"; -  auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](Block *block) { -    if (block->getParentOp()) -      return block->getParentOp()->getName().getStringRef(); -    return detachedBlockStr; -  }); -  llvm::report_fatal_error( -      "pattern '" + pattern.getDebugName() + -      "' produced IR that could not be legalized. " + "new ops: {" + -      llvm::join(newOpNames, ", ") + "}, " + "modified ops: {" + -      llvm::join(modifiedOpNames, ", ") + "}, " + "inserted block into ops: {" + -      llvm::join(insertedBlockNames, ", ") + "}"); +  llvm::report_fatal_error("pattern '" + pattern.getDebugName() + +                           "' produced IR that could not be legalized. " + +                           "new ops: {" + llvm::join(newOpNames, ", ") + "}, " + +                           "modified ops: {" + +                           llvm::join(modifiedOpNames, ", ") + "}");  }  LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) { @@ -2743,7 +2723,6 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {      }      rewriterImpl.patternNewOps.clear();      rewriterImpl.patternModifiedOps.clear(); -    rewriterImpl.patternInsertedBlocks.clear();      LLVM_DEBUG({        logFailure(rewriterImpl.logger, "pattern failed to match");        if (rewriterImpl.config.notifyCallback) { @@ -2777,15 +2756,12 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {      SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);      SetVector<Operation *> modifiedOps =          moveAndReset(rewriterImpl.patternModifiedOps); -    SetVector<Block *> insertedBlocks = -        moveAndReset(rewriterImpl.patternInsertedBlocks); -    auto result = legalizePatternResult(op, pattern, curState, newOps, -                                        modifiedOps, insertedBlocks); +    auto result = +        legalizePatternResult(op, pattern, curState, newOps, modifiedOps);      appliedPatterns.erase(&pattern);      if (failed(result)) {        if (!rewriterImpl.config.allowPatternRollback) -        reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps, -                                          insertedBlocks); +        reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);        rewriterImpl.resetState(curState, pattern.getDebugName());      }      if (config.listener) @@ -2823,8 +2799,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op,  LogicalResult OperationLegalizer::legalizePatternResult(      Operation *op, const Pattern &pattern, const RewriterState &curState,      const SetVector<Operation *> &newOps, -    const SetVector<Operation *> &modifiedOps, -    const SetVector<Block *> &insertedBlocks) { +    const SetVector<Operation *> &modifiedOps) {    
[[maybe_unused]] auto &impl = rewriter.getImpl();    assert(impl.pendingRootUpdates.empty() && "dangling root updates"); @@ -2843,8 +2818,7 @@ LogicalResult OperationLegalizer::legalizePatternResult(  #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS    // Legalize each of the actions registered during application. -  if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) || -      failed(legalizePatternRootUpdates(modifiedOps)) || +  if (failed(legalizePatternRootUpdates(modifiedOps)) ||        failed(legalizePatternCreatedOperations(newOps))) {      return failure();    } @@ -2853,53 +2827,6 @@ LogicalResult OperationLegalizer::legalizePatternResult(    return success();  } -LogicalResult OperationLegalizer::legalizePatternBlockRewrites( -    Operation *op, const SetVector<Block *> &insertedBlocks, -    const SetVector<Operation *> &newOps) { -  ConversionPatternRewriterImpl &impl = rewriter.getImpl(); -  SmallPtrSet<Operation *, 16> alreadyLegalized; - -  // If the pattern moved or created any blocks, make sure the types of block -  // arguments get legalized. -  for (Block *block : insertedBlocks) { -    if (impl.erasedBlocks.contains(block)) -      continue; - -    // Only check blocks outside of the current operation. -    Operation *parentOp = block->getParentOp(); -    if (!parentOp || parentOp == op || block->getNumArguments() == 0) -      continue; - -    // If the region of the block has a type converter, try to convert the block -    // directly. -    if (auto *converter = impl.regionToConverter.lookup(block->getParent())) { -      std::optional<TypeConverter::SignatureConversion> conversion = -          converter->convertBlockSignature(block); -      if (!conversion) { -        LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved " -                                           "block")); -        return failure(); -      } -      impl.applySignatureConversion(block, converter, *conversion); -      continue; -    } - -    // Otherwise, try to legalize the parent operation if it was not generated -    // by this pattern. This is because we will attempt to legalize the parent -    // operation, and blocks in regions created by this pattern will already be -    // legalized later on. 
-    if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) { -      if (failed(legalize(parentOp))) { -        LLVM_DEBUG(logFailure( -            impl.logger, "operation '{0}'({1}) became illegal after rewrite", -            parentOp->getName(), parentOp)); -        return failure(); -      } -    } -  } -  return success(); -} -  LogicalResult OperationLegalizer::legalizePatternCreatedOperations(      const SetVector<Operation *> &newOps) {    for (Operation *op : newOps) { @@ -3800,10 +3727,11 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,    TypeConverter::SignatureConversion result(type.getNumInputs());    SmallVector<Type, 1> newResults;    if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) || -      failed(typeConverter.convertTypes(type.getResults(), newResults)) || -      failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(), -                                         typeConverter, &result))) +      failed(typeConverter.convertTypes(type.getResults(), newResults)))      return failure(); +  if (!funcOp.getFunctionBody().empty()) +    rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result, +                                      &typeConverter);    // Update the function signature in-place.    auto newType = FunctionType::get(rewriter.getContext(),  | 
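The XeGPU changes above rework offset generation around the notion of a distribution unit: each unit covers subShapesLayout * subShape elements (capped at the source shape), and the source shape is tiled by such units. Below is a minimal, self-contained C++ sketch of that arithmetic, using the worked figures from the genCoordinates comment in the diff (128x256 workgroup data, 16x32 per-subgroup data, 4x2 subgroup layout); the variable and function names are illustrative only and are not part of the patch.

// Standalone sketch of the distribution-unit arithmetic described in the
// genCoordinates comment; all names here are illustrative.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> srcShape = {128, 256}; // workgroup-level data shape
  std::vector<int64_t> subShape = {16, 32};   // data handled by one subgroup
  std::vector<int64_t> layout   = {4, 2};     // subgroup layout

  // distUnitShape[i] = min(srcShape[i], layout[i] * subShape[i])
  std::vector<int64_t> distUnitShape, numUnits;
  for (size_t i = 0; i < srcShape.size(); ++i) {
    distUnitShape.push_back(std::min(srcShape[i], layout[i] * subShape[i]));
    numUnits.push_back(srcShape[i] / distUnitShape[i]);
  }

  // Prints "64x64" and "2x4", matching the worked example in the comment.
  std::cout << distUnitShape[0] << "x" << distUnitShape[1] << "\n";
  std::cout << numUnits[0] << "x" << numUnits[1] << "\n";

  // Enumerate the origin of each distribution unit within srcShape,
  // analogous to what StaticTileOffsetRange(srcShape, distUnitShape) yields:
  // (0, 0), (0, 64), (0, 128), (0, 192), (64, 0), ...
  for (int64_t r = 0; r < srcShape[0]; r += distUnitShape[0])
    for (int64_t c = 0; c < srcShape[1]; c += distUnitShape[1])
      std::cout << "(" << r << ", " << c << ")\n";
  return 0;
}

Within each such unit, the patch then adds the delinearized subgroup (or lane) id multiplied by subShape and wraps the sum modulo srcShape, which corresponds to the index::MulOp, arith::AddIOp, and index::RemUOp sequence visible in the genCoordinates hunk above.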
