aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r--mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp177
-rw-r--r--mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp85
2 files changed, 256 insertions, 6 deletions
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 1f239aa..519d9c8 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/LogicalResult.h"
namespace mlir {
#define GEN_PASS_DEF_SCFTOEMITC
@@ -106,7 +107,7 @@ static void assignValues(ValueRange values, ValueRange variables,
emitc::AssignOp::create(rewriter, loc, var, value);
}
-SmallVector<Value> loadValues(const SmallVector<Value> &variables,
+SmallVector<Value> loadValues(ArrayRef<Value> variables,
PatternRewriter &rewriter, Location loc) {
return llvm::map_to_vector<>(variables, [&](Value var) {
Type type = cast<emitc::LValueType>(var.getType()).getValueType();
@@ -116,16 +117,15 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables,
static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
ConversionPatternRewriter &rewriter,
- scf::YieldOp yield) {
+ scf::YieldOp yield, bool createYield = true) {
Location loc = yield.getLoc();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(yield);
SmallVector<Value> yieldOperands;
- if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) {
+ if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands)))
return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
- }
assignValues(yieldOperands, resultVariables, rewriter, loc);
@@ -336,11 +336,177 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite(
return success();
}
+// Lower scf::while to emitc::do using mutable variables to maintain loop state
+// across iterations. The do-while structure ensures the condition is evaluated
+// after each iteration, matching SCF while semantics.
+struct WhileLowering : public OpConversionPattern<WhileOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = whileOp.getLoc();
+ MLIRContext *context = loc.getContext();
+
+ // Create an emitc::variable op for each result. These variables will be
+ // assigned to by emitc::assign ops within the loop body.
+ SmallVector<Value> resultVariables;
+ if (failed(createVariablesForResults(whileOp, getTypeConverter(), rewriter,
+ resultVariables)))
+ return rewriter.notifyMatchFailure(whileOp,
+ "Failed to create result variables");
+
+ // Create variable storage for loop-carried values to enable imperative
+ // updates while maintaining SSA semantics at conversion boundaries.
+ SmallVector<Value> loopVariables;
+ if (failed(createVariablesForLoopCarriedValues(
+ whileOp, rewriter, loopVariables, loc, context)))
+ return failure();
+
+ if (failed(lowerDoWhile(whileOp, loopVariables, resultVariables, context,
+ rewriter, loc)))
+ return failure();
+
+ rewriter.setInsertionPointAfter(whileOp);
+
+ // Load the final result values from result variables.
+ SmallVector<Value> finalResults =
+ loadValues(resultVariables, rewriter, loc);
+ rewriter.replaceOp(whileOp, finalResults);
+
+ return success();
+ }
+
+private:
+ // Initialize variables for loop-carried values to enable state updates
+ // across iterations without SSA argument passing.
+ LogicalResult createVariablesForLoopCarriedValues(
+ WhileOp whileOp, ConversionPatternRewriter &rewriter,
+ SmallVectorImpl<Value> &loopVars, Location loc,
+ MLIRContext *context) const {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(whileOp);
+
+ emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
+
+ for (Value init : whileOp.getInits()) {
+ Type convertedType = getTypeConverter()->convertType(init.getType());
+ if (!convertedType)
+ return rewriter.notifyMatchFailure(whileOp, "type conversion failed");
+
+ emitc::VariableOp var = rewriter.create<emitc::VariableOp>(
+ loc, emitc::LValueType::get(convertedType), noInit);
+ rewriter.create<emitc::AssignOp>(loc, var.getResult(), init);
+ loopVars.push_back(var);
+ }
+
+ return success();
+ }
+
+ // Lower scf.while to emitc.do.
+ LogicalResult lowerDoWhile(WhileOp whileOp, ArrayRef<Value> loopVars,
+ ArrayRef<Value> resultVars, MLIRContext *context,
+ ConversionPatternRewriter &rewriter,
+ Location loc) const {
+ // Create a global boolean variable to store the loop condition state.
+ Type i1Type = IntegerType::get(context, 1);
+ auto globalCondition =
+ rewriter.create<emitc::VariableOp>(loc, emitc::LValueType::get(i1Type),
+ emitc::OpaqueAttr::get(context, ""));
+ Value conditionVal = globalCondition.getResult();
+
+ auto loweredDo = rewriter.create<emitc::DoOp>(loc);
+
+ // Convert region types to match the target dialect type system.
+ if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(),
+ *getTypeConverter(), nullptr)) ||
+ failed(rewriter.convertRegionTypes(&whileOp.getAfter(),
+ *getTypeConverter(), nullptr))) {
+ return rewriter.notifyMatchFailure(whileOp,
+ "region types conversion failed");
+ }
+
+ // Prepare the before region (condition evaluation) for merging.
+ Block *beforeBlock = &whileOp.getBefore().front();
+ Block *bodyBlock = rewriter.createBlock(&loweredDo.getBodyRegion());
+ rewriter.setInsertionPointToStart(bodyBlock);
+
+ // Load current variable values to use as initial arguments for the
+ // condition block.
+ SmallVector<Value> replacingValues = loadValues(loopVars, rewriter, loc);
+ rewriter.mergeBlocks(beforeBlock, bodyBlock, replacingValues);
+
+ Operation *condTerminator =
+ loweredDo.getBodyRegion().back().getTerminator();
+ scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator);
+ rewriter.setInsertionPoint(condOp);
+
+ // Update result variables with values from scf::condition.
+ SmallVector<Value> conditionArgs;
+ for (Value arg : condOp.getArgs()) {
+ conditionArgs.push_back(rewriter.getRemappedValue(arg));
+ }
+ assignValues(conditionArgs, resultVars, rewriter, loc);
+
+ // Convert scf.condition to condition variable assignment.
+ Value condition = rewriter.getRemappedValue(condOp.getCondition());
+ rewriter.create<emitc::AssignOp>(loc, conditionVal, condition);
+
+ // Wrap body region in conditional to preserve scf semantics. Only create
+ // ifOp if after-region is non-empty.
+ if (whileOp.getAfterBody()->getOperations().size() > 1) {
+ auto ifOp = rewriter.create<emitc::IfOp>(loc, condition, false, false);
+
+ // Prepare the after region (loop body) for merging.
+ Block *afterBlock = &whileOp.getAfter().front();
+ Block *ifBodyBlock = rewriter.createBlock(&ifOp.getBodyRegion());
+
+ // Replacement values for after block using condition op arguments.
+ SmallVector<Value> afterReplacingValues;
+ for (Value arg : condOp.getArgs())
+ afterReplacingValues.push_back(rewriter.getRemappedValue(arg));
+
+ rewriter.mergeBlocks(afterBlock, ifBodyBlock, afterReplacingValues);
+
+ if (failed(lowerYield(whileOp, loopVars, rewriter,
+ cast<scf::YieldOp>(ifBodyBlock->getTerminator()))))
+ return failure();
+ }
+
+ rewriter.eraseOp(condOp);
+
+ // Create condition region that loads from the flag variable.
+ Region &condRegion = loweredDo.getConditionRegion();
+ Block *condBlock = rewriter.createBlock(&condRegion);
+ rewriter.setInsertionPointToStart(condBlock);
+
+ auto exprOp = rewriter.create<emitc::ExpressionOp>(
+ loc, i1Type, conditionVal, /*do_not_inline=*/false);
+ Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion());
+
+ // Set up the expression block to load the condition variable.
+ exprBlock->addArgument(conditionVal.getType(), loc);
+ rewriter.setInsertionPointToStart(exprBlock);
+
+ // Load the condition value and yield it as the expression result.
+ Value cond =
+ rewriter.create<emitc::LoadOp>(loc, i1Type, exprBlock->getArgument(0));
+ rewriter.create<emitc::YieldOp>(loc, cond);
+
+ // Yield the expression as the condition region result.
+ rewriter.setInsertionPointToEnd(condBlock);
+ rewriter.create<emitc::YieldOp>(loc, exprOp);
+
+ return success();
+ }
+};
+
void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<ForLowering>(typeConverter, patterns.getContext());
patterns.add<IfLowering>(typeConverter, patterns.getContext());
patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext());
+ patterns.add<WhileLowering>(typeConverter, patterns.getContext());
}
void SCFToEmitCPass::runOnOperation() {
@@ -357,7 +523,8 @@ void SCFToEmitCPass::runOnOperation() {
// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
- target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
+ target
+ .addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::WhileOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index 57877b8..f449d90 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -214,6 +214,10 @@ static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
return op.getCacheControl();
}
+static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
+ return op.getCacheControl();
+}
+
static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
return op.getCacheControl();
}
@@ -222,6 +226,10 @@ static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
return op.getCacheControl();
}
+static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
+ return op.getCacheControl();
+}
+
static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
if (op->hasAttr("cache_control")) {
auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control");
@@ -263,6 +271,7 @@ getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
std::is_same_v<OpType, BlockPrefetch2dOp> ||
std::is_same_v<OpType, LLVM::LoadOp> ||
+ std::is_same_v<OpType, BlockLoadOp> ||
std::is_same_v<OpType, PrefetchOp>;
const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
@@ -618,6 +627,77 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
return success();
}
};
+
+template <typename OpType>
+class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
+ // Get OpenCL function name
+ // https://registry.khronos.org/OpenCL/extensions/
+ // intel/cl_intel_subgroup_local_block_io.html
+ std::string funcName{"intel_sub_group_block_"};
+ // Value or Result type can be vector or scalar
+ Type valOrResTy;
+ if constexpr (isStore) {
+ funcName += "write_u";
+ valOrResTy = op.getVal().getType();
+ } else {
+ funcName += "read_u";
+ valOrResTy = op.getType();
+ }
+ // Get element type of the vector/scalar
+ VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
+ Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
+ funcName += getTypeMangling(elemType);
+ if (vecTy)
+ funcName += std::to_string(vecTy.getNumElements());
+ SmallVector<Type, 2> argTypes{};
+ // XeVM BlockLoad/StoreOp always use signless integer types
+ // but OpenCL builtins expect unsigned types
+ // use unsigned types for mangling
+ SmallVector<bool, 2> isUnsigned{};
+ // arg0: pointer to the src/dst address
+ // arg1 - only if store : vector to store
+ // Prepare arguments
+ SmallVector<Value, 2> args{};
+ args.push_back(op.getPtr());
+ argTypes.push_back(op.getPtr().getType());
+ isUnsigned.push_back(true);
+ Type retType;
+ if constexpr (isStore) {
+ args.push_back(op.getVal());
+ argTypes.push_back(op.getVal().getType());
+ isUnsigned.push_back(true);
+ retType = LLVM::LLVMVoidType::get(rewriter.getContext());
+ } else {
+ retType = valOrResTy;
+ }
+ funcName = std::string("_Z") + std::to_string(funcName.size()) + funcName +
+ "PU3AS" +
+ std::to_string(op.getPtr().getType().getAddressSpace());
+ funcName += getTypeMangling(elemType, /*isUnsigned=*/true);
+ if constexpr (isStore)
+ funcName += getTypeMangling(valOrResTy, /*isUnsigned=*/true);
+ LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
+
+ LLVM::CallOp call =
+ createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
+ {}, funcAttr, op.getOperation());
+ if (std::optional<ArrayAttr> optCacheControls =
+ getCacheControlMetadata(rewriter, op)) {
+ call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+ }
+ if constexpr (isStore)
+ rewriter.eraseOp(op);
+ else
+ rewriter.replaceOp(op, call->getResult(0));
+ return success();
+ }
+};
+
template <typename OpType>
class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
using OpConversionPattern<OpType>::OpConversionPattern;
@@ -693,7 +773,10 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
- LLVMLoadStoreToOCLPattern<LLVM::StoreOp>>(patterns.getContext());
+ LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
+ BlockLoadStore1DToOCLPattern<BlockLoadOp>,
+ BlockLoadStore1DToOCLPattern<BlockStoreOp>>(
+ patterns.getContext());
}
void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry &registry) {