diff options
Diffstat (limited to 'mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp')
| -rw-r--r-- | mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 38 | 
1 files changed, 30 insertions, 8 deletions
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index abd12e2..7b6c4b6 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -175,13 +175,13 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,  LogicalResult  IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy, -                      UnitAttr subgroup_block_io, +                      UnitAttr subgroup_block_io, DistributeLayoutAttr layout,                        function_ref<InFlightDiagnostic()> emitError) {    if (!dataTy) {      if (subgroup_block_io)        return emitError() << "subgroup_block_io " -                            "are only allowed when result is a 1D VectorType."; +                            "are only allowed when result is a VectorType.";      else        return success();    } @@ -192,15 +192,37 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,    ArrayRef<int64_t> dataShape = dataTy.getShape();    ArrayRef<int64_t> mdescShape = mdescTy.getShape(); +  SmallVector<int64_t> blockShape = mdescTy.getBlockShape(); +  ArrayAttr strideAttr = mdescTy.getStrideAttr(); +  SmallVector<int64_t> strides; +  for (Attribute attr : strideAttr.getValue()) { +    strides.push_back(cast<IntegerAttr>(attr).getInt()); +  } +  if (subgroup_block_io && layout) { +    auto laneData = layout.getEffectiveLaneDataAsInt(); +    auto laneLayout = layout.getEffectiveLaneLayoutAsInt(); +    if (!laneData.empty()) { +      bool isLaneDataContiguous = +          std::all_of(laneData.begin(), std::prev(laneData.end()), +                      [](int x) { return x == 1; }); +      if (!isLaneDataContiguous) +        return emitError() << "With subgroup_block_io, accessed data must be " +                              "contiguous and coalesced."; +      for (size_t i = 0; i < laneData.size(); ++i) { +        if (laneLayout[i] != blockShape[i]) +          return emitError() << "With subgroup_block_io, the block shape must " +                                "match the lane layout."; +        if (laneLayout[i] != 1 && strides[i] != 1) +          return emitError() << "With subgroup_block_io, the distributed " +                                "dimensions must be contiguous."; +      } +    } +  }    if (dataShape.size() == 2) { -    if (subgroup_block_io) -      return emitError() << "subgroup_block_io " -                            "are only allowed when result is a 1D VectorType.";      if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),                       [](auto p) { return std::get<0>(p) > std::get<1>(p); }))        return emitError() << "data shape must not exceed mem_desc shape.";    } else { -    SmallVector<int64_t> blockShape = mdescTy.getBlockShape();      // if the subgroup_block_io attribute is set,  mdescTy must have block      // attribute      if (subgroup_block_io && !blockShape.size()) @@ -1105,7 +1127,7 @@ LogicalResult LoadMatrixOp::verify() {    MemDescType mdescTy = getMemDesc().getType();    return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io, -                               [&]() { return emitError(); }); +                               getLayoutAttr(), [&]() { return emitError(); });  }  //===----------------------------------------------------------------------===// @@ -1129,7 +1151,7 @@ LogicalResult StoreMatrixOp::verify() {    UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();    MemDescType mdescTy = getMemDesc().getType();    return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io, -                               [&]() { return emitError(); }); +                               getLayoutAttr(), [&]() { return emitError(); });  }  namespace mlir {  | 
