aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/NVGPUToNVVM
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/NVGPUToNVVM')
-rw-r--r--mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp53
1 files changed, 25 insertions, 28 deletions
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 905287e1..2549a9c 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -21,19 +21,17 @@
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#define DEBUG_TYPE "nvgpu-to-nvvm"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define DBGSE() (llvm::dbgs())
namespace mlir {
#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
@@ -1106,13 +1104,13 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
// // [0,14) start_address
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
- LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
- << "leading_off:" << leadDimVal << "\t"
- << "stride_off :" << strideDimVal << "\t"
- << "base_offset:" << offsetVal << "\t"
- << "layout_type:" << swizzle << " ("
- << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
- << ")\n start_addr : " << baseAddr << "\n");
+ LDBG() << "Generating warpgroup.descriptor: "
+ << "leading_off:" << leadDimVal << "\t"
+ << "stride_off :" << strideDimVal << "\t"
+ << "base_offset:" << offsetVal << "\t"
+ << "layout_type:" << swizzle << " ("
+ << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
+ << ")\n start_addr : " << baseAddr;
rewriter.replaceOp(op, dsc);
return success();
@@ -1282,8 +1280,8 @@ struct NVGPUWarpgroupMmaOpLowering
} else {
llvm_unreachable("msg: not supported K shape");
}
- LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
- << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
+ LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
+ << ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
}
/// Generates WGMMATypesAttr from MLIR Type
@@ -1367,9 +1365,9 @@ struct NVGPUWarpgroupMmaOpLowering
int tileShapeA = matrixTypeA.getDimSize(1);
int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
incrementVal = incrementVal >> exclude4LSB;
- LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
- << "] [wgmma descriptors] Descriptor A + "
- << incrementVal << " | \t ");
+ LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
+ << "] [wgmma descriptors] Descriptor A + " << incrementVal
+ << " | \t ";
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
@@ -1392,7 +1390,7 @@ struct NVGPUWarpgroupMmaOpLowering
int byte = elemB.getIntOrFloatBitWidth() / 8;
int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
incrementVal = incrementVal >> exclude4LSB;
- LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
+ LDBG() << "Descriptor B + " << incrementVal;
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
@@ -1401,15 +1399,14 @@ struct NVGPUWarpgroupMmaOpLowering
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC) {
- LLVM_DEBUG(DBGS() << "\t wgmma."
- << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
- << "(A[" << (iterationM * wgmmaM) << ":"
- << (iterationM * wgmmaM) + wgmmaM << "]["
- << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "] * "
- << " B[" << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
- << wgmmaN << "])\n");
+ LDBG() << "\t wgmma."
+ << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A["
+ << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM
+ << "][" << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "] * "
+ << " B[" << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN
+ << "])";
Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
@@ -1468,9 +1465,9 @@ struct NVGPUWarpgroupMmaOpLowering
totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
- LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
- << "] += A[" << totalM << "][" << totalK << "] * B["
- << totalK << "][" << totalN << "] ---===\n");
+ LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
+ << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN
+ << "] ---===";
// Find the shape for one wgmma instruction
findWgmmaShape(