diff options
author | Adam Siemieniuk <adam.siemieniuk@intel.com> | 2025-06-12 13:45:19 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-06-12 13:45:19 +0200 |
commit | d698ede748e66f5519cb8481abc2df89a994a059 (patch) | |
tree | d57ae736dc085ad8ad632ae89039e1731be62f63 | |
parent | 013034cd0f5ae19ef02fc35a83362874e727f13c (diff) | |
download | llvm-d698ede748e66f5519cb8481abc2df89a994a059.zip llvm-d698ede748e66f5519cb8481abc2df89a994a059.tar.gz llvm-d698ede748e66f5519cb8481abc2df89a994a059.tar.bz2 |
[mlir][amx] Restore conversion interface for AMX (#143871)
Restores mistakenly removed AMX interface which ensures that the custom
tile type is converted to its LLVM equivalent within other operations
such as control flow.
Fix after #140559
-rw-r--r-- | mlir/include/mlir/Dialect/AMX/Transforms.h | 3 | ||||
-rw-r--r-- | mlir/include/mlir/InitAllExtensions.h | 2 | ||||
-rw-r--r-- | mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp | 19 | ||||
-rw-r--r-- | mlir/test/Target/LLVMIR/amx.mlir | 20 |
4 files changed, 44 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h index 4a751d9..7391ec2 100644 --- a/mlir/include/mlir/Dialect/AMX/Transforms.h +++ b/mlir/include/mlir/Dialect/AMX/Transforms.h @@ -25,6 +25,9 @@ void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, /// intrinsics. void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target); +/// Register LLVM conversion interface for AMX dialect. +void registerConvertAMXToLLVMInterface(DialectRegistry ®istry); + } // namespace mlir #endif // MLIR_DIALECT_AMX_TRANSFORMS_H diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index 7dcbabe..f356b91b 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -32,6 +32,7 @@ #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" #include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h" #include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h" @@ -85,6 +86,7 @@ inline void registerAllExtensions(DialectRegistry ®istry) { registerConvertOpenMPToLLVMInterface(registry); registerConvertSCFToEmitCInterface(registry); ub::registerConvertUBToLLVMInterface(registry); + registerConvertAMXToLLVMInterface(registry); gpu::registerConvertGpuToLLVMInterface(registry); NVVM::registerConvertGpuToNVVMInterface(registry); vector::registerConvertVectorToLLVMInterface(registry); diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp index 7471dc7..37aebc9 100644 --- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp @@ -60,3 +60,22 @@ void mlir::populateAMXLegalizeForLLVMExportPatterns( void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) { target.addIllegalDialect<AMXDialect>(); } + +namespace { +/// Implement the interface to convert AMX to LLVM. +struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface { + using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; + + void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) const final { + populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns); + } +}; +} // namespace + +void mlir::registerConvertAMXToLLVMInterface(DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) { + dialect->addInterfaces<AMXToLLVMDialectInterface>(); + }); +} diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir index 0944750..abdf2fe 100644 --- a/mlir/test/Target/LLVMIR/amx.mlir +++ b/mlir/test/Target/LLVMIR/amx.mlir @@ -88,3 +88,23 @@ func.func @amx_tile_muli(%matA: memref<?x?xi8>, %matB: memref<?x?xi8>, amx.tile_store %out[%c16, %c16], %res3 : memref<?x?xi8>, !amx.tile<16x16xi32> return } + +// CHECK-LABEL: define void @amx_tile_type_through_cf +func.func @amx_tile_type_through_cf(%src: memref<?x?xi8>, %out: memref<?x?xi8>, + %idx: index, %cond: i1) { + cf.cond_br %cond, ^bb1, ^bb2 +^bb1: // pred: ^bb0 + // CHECK: call x86_amx @llvm.x86.tileloadd64.internal + %0 = amx.tile_load %src[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8> + cf.br ^bb3(%0 : !amx.tile<16x64xi8>) +^bb2: // pred: ^bb0 + // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64) + %1 = amx.tile_zero : !amx.tile<16x64xi8> + cf.br ^bb3(%1 : !amx.tile<16x64xi8>) +^bb3(%2: !amx.tile<16x64xi8>): // 2 preds: ^bb1, ^bb2 + cf.br ^bb4 +^bb4: // pred: ^bb3 + // CHECK: call void @llvm.x86.tilestored64.internal + amx.tile_store %out[%idx, %idx], %2 : memref<?x?xi8>, !amx.tile<16x64xi8> + return +} |