aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp')
-rw-r--r--mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp7
1 files changed, 3 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index 809d634..9e5ea93 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -168,8 +168,7 @@ nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
const WarpMatrixInfo &fragmentType) {
Type elementType = fragmentType.vectorType.getElementType();
ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
- FailureOr<nvgpu::FragmentElementInfo> regInfo =
- getMmaSyncRegisterType(fragmentType);
+ FailureOr<FragmentElementInfo> regInfo = getMmaSyncRegisterType(fragmentType);
if (failed(regInfo))
return failure();
@@ -199,8 +198,8 @@ nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
(logicalValueIdDim % elementsPerRegister)});
}
-FailureOr<nvgpu::LdMatrixParams>
-nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
+FailureOr<LdMatrixParams> nvgpu::getLdMatrixParams(const WarpMatrixInfo &type,
+ bool transpose) {
LdMatrixParams params;
Type elType = type.vectorType.getElementType();
params.fragmentType = type.vectorType;