diff options
Diffstat (limited to 'mlir/lib/Dialect/LLVMIR/IR')
| -rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 57 | 
4 files changed, 60 insertions, 3 deletions
| diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 3eae67f..2731069 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -698,7 +698,7 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,                         return structType.getBody()[memberIndex];                       return nullptr;                     }) -                   .Default(Type(nullptr)); +                   .Default(nullptr);    }  } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index cee943d..7d9058c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -1111,7 +1111,7 @@ memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,            .Case<IntegerType, FloatType>([](auto type) {              return type.getWidth() % 8 == 0 && type.getWidth() > 0;            }) -          .Default([](Type) { return false; }); +          .Default(false);    if (!canConvertType)      return false; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index ac35eea..ce93d18 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -798,7 +798,7 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {            // clang-format on            .Case<PtrLikeTypeInterface>(                [](Type type) { return isCompatiblePtrType(type); }) -          .Default([](Type) { return false; }); +          .Default(false);    if (!result)      compatibleTypes.erase(type); diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index f0de4db..a5ffb9e 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -896,6 +896,12 @@ std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,    } else if (type == NVVM::MMATypes::f32) {      elementType = builder.getF32Type();      numberElements = 8; +  } else if (type == NVVM::MMATypes::f64) { +    elementType = builder.getF64Type(); +    if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b) +      numberElements = 1; +    else +      numberElements = 2;    } else if (type == NVVM::MMATypes::tf32) {      elementType = builder.getI32Type();      numberElements = 4; @@ -954,6 +960,14 @@ LogicalResult NVVM::WMMALoadOp::verify() {      return emitOpError() << "invalid attribute combination";    std::pair<Type, unsigned> typeInfo = inferMMATypeFromMNK(        getEltype(), getFrag(), getM(), getN(), getK(), getContext()); +  // Special case for f64 fragments +  Type f64Ty = Float64Type::get(getContext()); +  if (typeInfo.first == f64Ty && typeInfo.second == 1) { +    if (getType() != f64Ty) +      return emitOpError("expected destination type to be f64"); +    return success(); +  } +  // Everything else is a struct    Type dstType = LLVM::LLVMStructType::getLiteral(        getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));    if (getType() != dstType) @@ -1608,9 +1622,52 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,  }  //===----------------------------------------------------------------------===// +// getPtx methods +//===----------------------------------------------------------------------===// + +std::string NVVM::MBarrierInitOp::getPtx() { +  unsigned addressSpace = +      llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace(); +  return (addressSpace == NVVMMemorySpace::Shared) +             ? std::string("mbarrier.init.shared.b64 [%0], %1;") +             : std::string("mbarrier.init.b64 [%0], %1;"); +} + +//===----------------------------------------------------------------------===//  // getIntrinsicID/getIntrinsicIDAndArgs methods  //===----------------------------------------------------------------------===// +mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs( +    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { +  auto thisOp = cast<NVVM::MBarrierInitOp>(op); +  unsigned addressSpace = +      llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType()) +          .getAddressSpace(); +  llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared) +                               ? llvm::Intrinsic::nvvm_mbarrier_init_shared +                               : llvm::Intrinsic::nvvm_mbarrier_init; + +  // Fill the Intrinsic Args +  llvm::SmallVector<llvm::Value *> args; +  args.push_back(mt.lookupValue(thisOp.getAddr())); +  args.push_back(mt.lookupValue(thisOp.getCount())); + +  return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair MBarrierInvalOp::getIntrinsicIDAndArgs( +    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { +  auto thisOp = cast<NVVM::MBarrierInvalOp>(op); +  unsigned addressSpace = +      llvm::cast<LLVM::LLVMPointerType>(thisOp.getAddr().getType()) +          .getAddressSpace(); +  llvm::Intrinsic::ID id = (addressSpace == NVVMMemorySpace::Shared) +                               ? llvm::Intrinsic::nvvm_mbarrier_inval_shared +                               : llvm::Intrinsic::nvvm_mbarrier_inval; + +  return {id, {mt.lookupValue(thisOp.getAddr())}}; +} +  #define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \    llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix | 
