diff options
Diffstat (limited to 'mlir/lib/Dialect')
23 files changed, 221 insertions, 53 deletions
| diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 898d76c..980442e 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2751,7 +2751,7 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {            .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })            .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })            .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; }) -          .Default([](Operation *op) { return std::nullopt; }); +          .Default(std::nullopt);    if (!maybeKind) {      return std::nullopt;    } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index d9d6934..8655ed3 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -95,12 +95,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,  /// Return the FuncOp called by `callOp`.  static FuncOp getCalledFunction(CallOpInterface callOp,                                  SymbolTableCollection &symbolTables) { -  SymbolRefAttr sym = -      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee()); -  if (!sym) -    return nullptr; -  return dyn_cast_or_null<FuncOp>( -      symbolTables.lookupNearestSymbolFrom(callOp, sym)); +  return dyn_cast_or_null<FuncOp>(callOp.resolveCallableInTable(&symbolTables));  }  /// Return the FuncOp called by `callOp`. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index aa53f94..c233e24 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -285,12 +285,8 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {  static func::FuncOp  getCalledFunction(func::CallOp callOp,                    mlir::SymbolTableCollection &symbolTable) { -  SymbolRefAttr sym = -      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee()); -  if (!sym) -    return nullptr;    return dyn_cast_or_null<func::FuncOp>( -      symbolTable.lookupNearestSymbolFrom(callOp, sym)); +      callOp.resolveCallableInTable(&symbolTable));  }  /// Return "true" if the given function signature has tensor semantics. diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt index 47740d3..e9da135 100644 --- a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@  add_mlir_dialect_library(MLIRControlFlowTransforms    BufferDeallocationOpInterfaceImpl.cpp    BufferizableOpInterfaceImpl.cpp +  StructuralTypeConversions.cpp    ADDITIONAL_HEADER_DIRS    ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp new file mode 100644 index 0000000..5e2a742 --- /dev/null +++ b/mlir/lib/Dialect/ControlFlow/Transforms/StructuralTypeConversions.cpp @@ -0,0 +1,169 @@ +//===- TypeConversion.cpp - Type Conversion of Unstructured Control Flow --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert MLIR standard and builtin dialects +// into the LLVM IR dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +namespace { + +/// Helper function for converting branch ops. This function converts the +/// signature of the given block. If the new block signature is different from +/// `expectedTypes`, returns "failure". +static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter, +                                            const TypeConverter *converter, +                                            Operation *branchOp, Block *block, +                                            TypeRange expectedTypes) { +  assert(converter && "expected non-null type converter"); +  assert(!block->isEntryBlock() && "entry blocks have no predecessors"); + +  // There is nothing to do if the types already match. +  if (block->getArgumentTypes() == expectedTypes) +    return block; + +  // Compute the new block argument types and convert the block. +  std::optional<TypeConverter::SignatureConversion> conversion = +      converter->convertBlockSignature(block); +  if (!conversion) +    return rewriter.notifyMatchFailure(branchOp, +                                       "could not compute block signature"); +  if (expectedTypes != conversion->getConvertedTypes()) +    return rewriter.notifyMatchFailure( +        branchOp, +        "mismatch between adaptor operand types and computed block signature"); +  return rewriter.applySignatureConversion(block, *conversion, converter); +} + +/// Flatten the given value ranges into a single vector of values. +static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) { +  SmallVector<Value> result; +  for (const ValueRange &vals : values) +    llvm::append_range(result, vals); +  return result; +} + +/// Convert the destination block signature (if necessary) and change the +/// operands of the branch op. +struct BranchOpConversion : public OpConversionPattern<cf::BranchOp> { +  using OpConversionPattern<cf::BranchOp>::OpConversionPattern; + +  LogicalResult +  matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor, +                  ConversionPatternRewriter &rewriter) const override { +    SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands()); +    FailureOr<Block *> convertedBlock = +        getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(), +                          TypeRange(ValueRange(flattenedAdaptor))); +    if (failed(convertedBlock)) +      return failure(); +    rewriter.replaceOpWithNewOp<cf::BranchOp>(op, flattenedAdaptor, +                                              *convertedBlock); +    return success(); +  } +}; + +/// Convert the destination block signatures (if necessary) and change the +/// operands of the branch op. +struct CondBranchOpConversion : public OpConversionPattern<cf::CondBranchOp> { +  using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern; + +  LogicalResult +  matchAndRewrite(cf::CondBranchOp op, OneToNOpAdaptor adaptor, +                  ConversionPatternRewriter &rewriter) const override { +    SmallVector<Value> flattenedAdaptorTrue = +        flattenValues(adaptor.getTrueDestOperands()); +    SmallVector<Value> flattenedAdaptorFalse = +        flattenValues(adaptor.getFalseDestOperands()); +    if (!llvm::hasSingleElement(adaptor.getCondition())) +      return rewriter.notifyMatchFailure(op, +                                         "expected single element condition"); +    FailureOr<Block *> convertedTrueBlock = +        getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(), +                          TypeRange(ValueRange(flattenedAdaptorTrue))); +    if (failed(convertedTrueBlock)) +      return failure(); +    FailureOr<Block *> convertedFalseBlock = +        getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(), +                          TypeRange(ValueRange(flattenedAdaptorFalse))); +    if (failed(convertedFalseBlock)) +      return failure(); +    rewriter.replaceOpWithNewOp<cf::CondBranchOp>( +        op, llvm::getSingleElement(adaptor.getCondition()), +        flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(), +        *convertedTrueBlock, *convertedFalseBlock); +    return success(); +  } +}; + +/// Convert the destination block signatures (if necessary) and change the +/// operands of the switch op. +struct SwitchOpConversion : public OpConversionPattern<cf::SwitchOp> { +  using OpConversionPattern<cf::SwitchOp>::OpConversionPattern; + +  LogicalResult +  matchAndRewrite(cf::SwitchOp op, OpAdaptor adaptor, +                  ConversionPatternRewriter &rewriter) const override { +    // Get or convert default block. +    FailureOr<Block *> convertedDefaultBlock = getConvertedBlock( +        rewriter, getTypeConverter(), op, op.getDefaultDestination(), +        TypeRange(adaptor.getDefaultOperands())); +    if (failed(convertedDefaultBlock)) +      return failure(); + +    // Get or convert all case blocks. +    SmallVector<Block *> caseDestinations; +    SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands(); +    for (auto it : llvm::enumerate(op.getCaseDestinations())) { +      Block *b = it.value(); +      FailureOr<Block *> convertedBlock = +          getConvertedBlock(rewriter, getTypeConverter(), op, b, +                            TypeRange(caseOperands[it.index()])); +      if (failed(convertedBlock)) +        return failure(); +      caseDestinations.push_back(*convertedBlock); +    } + +    rewriter.replaceOpWithNewOp<cf::SwitchOp>( +        op, adaptor.getFlag(), *convertedDefaultBlock, +        adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(), +        caseDestinations, caseOperands); +    return success(); +  } +}; + +} // namespace + +void mlir::cf::populateCFStructuralTypeConversions( +    const TypeConverter &typeConverter, RewritePatternSet &patterns, +    PatternBenefit benefit) { +  patterns.add<BranchOpConversion, CondBranchOpConversion, SwitchOpConversion>( +      typeConverter, patterns.getContext(), benefit); +} + +void mlir::cf::populateCFStructuralTypeConversionTarget( +    const TypeConverter &typeConverter, ConversionTarget &target) { +  target.addDynamicallyLegalOp<cf::BranchOp, cf::CondBranchOp, cf::SwitchOp>( +      [&](Operation *op) { return typeConverter.isLegal(op->getOperands()); }); +} + +void mlir::cf::populateCFStructuralTypeConversionsAndLegality( +    const TypeConverter &typeConverter, RewritePatternSet &patterns, +    ConversionTarget &target, PatternBenefit benefit) { +  populateCFStructuralTypeConversions(typeConverter, patterns, benefit); +  populateCFStructuralTypeConversionTarget(typeConverter, target); +} diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp index d2c2138..025d1ac 100644 --- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp @@ -330,7 +330,7 @@ static Value getBase(Value v) {                v = op.getSrc();                return true;              }) -            .Default([](Operation *) { return false; }); +            .Default(false);      if (!shouldContinue)        break;    } @@ -354,7 +354,7 @@ static Value propagatesCapture(Operation *op) {        .Case([](memref::TransposeOp transpose) { return transpose.getIn(); })        .Case<memref::ExpandShapeOp, memref::CollapseShapeOp>(            [](auto op) { return op.getSrc(); }) -      .Default([](Operation *) { return Value(); }); +      .Default(nullptr);  }  /// Returns `true` if the given operation is known to capture the given value, @@ -371,7 +371,7 @@ static std::optional<bool> getKnownCapturingStatus(Operation *op, Value v) {        // These operations are known not to capture.        .Case([](memref::DeallocOp) { return false; })        // By default, we don't know anything. -      .Default([](Operation *) { return std::nullopt; }); +      .Default(std::nullopt);  }  /// Returns `true` if the value may be captured by any of its users, i.e., if diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 3eae67f..2731069 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -698,7 +698,7 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,                         return structType.getBody()[memberIndex];                       return nullptr;                     }) -                   .Default(Type(nullptr)); +                   .Default(nullptr);    }  } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index cee943d..7d9058c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -1111,7 +1111,7 @@ memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,            .Case<IntegerType, FloatType>([](auto type) {              return type.getWidth() % 8 == 0 && type.getWidth() > 0;            }) -          .Default([](Type) { return false; }); +          .Default(false);    if (!canConvertType)      return false; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index ac35eea..ce93d18 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -798,7 +798,7 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) {            // clang-format on            .Case<PtrLikeTypeInterface>(                [](Type type) { return isCompatiblePtrType(type); }) -          .Default([](Type) { return false; }); +          .Default(false);    if (!result)      compatibleTypes.erase(type); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index cbc565b..3dc45ed 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1474,6 +1474,8 @@ void MapOp::getAsmBlockArgumentNames(Region ®ion,                                       OpAsmSetValueNameFn setNameFn) {    for (Value v : getRegionInputArgs())      setNameFn(v, "in"); +  for (Value v : getRegionOutputArgs()) +    setNameFn(v, "init");  }  void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { @@ -1495,14 +1497,14 @@ void MapOp::build(    if (bodyBuild)      buildGenericRegion(builder, result.location, *result.regions.front(), -                       inputs, /*outputs=*/{}, bodyBuild); +                       inputs, /*outputs=*/{init}, bodyBuild);  }  static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,                                   const OperationName &payloadOpName,                                   const NamedAttrList &payloadOpAttrs,                                   ArrayRef<Value> operands, -                                 bool initFirst = false) { +                                 bool initFirst = false, bool mapInit = true) {    OpBuilder b(parser.getContext());    Region *body = result.addRegion();    Block &block = body->emplaceBlock(); @@ -1516,12 +1518,13 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,    // If initFirst flag is enabled, we consider init as the first position of    // payload operands.    if (initFirst) { -    payloadOpOperands.push_back(block.getArguments().back()); +    if (mapInit) +      payloadOpOperands.push_back(block.getArguments().back());      for (const auto &arg : block.getArguments().drop_back())        payloadOpOperands.push_back(arg);    } else {      payloadOpOperands = {block.getArguments().begin(), -                         block.getArguments().end()}; +                         block.getArguments().end() - int(!mapInit)};    }    Operation *payloadOp = b.create( @@ -1553,8 +1556,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {    if (payloadOpName.has_value()) {      if (!result.operands.empty())        addBodyWithPayloadOp(parser, result, payloadOpName.value(), -                           payloadOpAttrs, -                           ArrayRef(result.operands).drop_back()); +                           payloadOpAttrs, ArrayRef(result.operands), false, +                           false);      else        result.addRegion();    } else { @@ -1570,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {    return success();  } -static bool canUseShortForm(Block *body, bool initFirst = false) { +static bool canUseShortForm(Block *body, bool initFirst = false, +                            bool mapInit = true) { +  // `intFirst == true` implies that we want to map init arg +  if (initFirst && !mapInit) +    return false;    // Check if the body can be printed in short form. The following 4 conditions    // must be satisfied: @@ -1582,7 +1589,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {    // 2) The payload op must have the same number of operands as the number of    //    block arguments.    if (payload.getNumOperands() == 0 || -      payload.getNumOperands() != body->getNumArguments()) +      payload.getNumOperands() != body->getNumArguments() - int(!mapInit))      return false;    // 3) If `initFirst` is true (e.g., for reduction ops), the init block @@ -1600,7 +1607,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {      }    } else {      for (const auto &[operand, bbArg] : -         llvm::zip(payload.getOperands(), body->getArguments())) { +         llvm::zip(payload.getOperands(), +                   body->getArguments().drop_back(int(!mapInit)))) {        if (bbArg != operand)          return false;      } @@ -1632,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {  void MapOp::print(OpAsmPrinter &p) {    Block *mapper = getBody(); -  bool useShortForm = canUseShortForm(mapper); +  bool useShortForm = +      canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);    if (useShortForm) {      printShortForm(p, &mapper->getOperations().front());    } @@ -1658,11 +1667,13 @@ LogicalResult MapOp::verify() {    auto *bodyBlock = getBody();    auto blockArgs = bodyBlock->getArguments(); -  // Checks if the number of `inputs` match the arity of the `mapper` region. -  if (getInputs().size() != blockArgs.size()) +  // Checks if the number of `inputs` + `init` match the arity of the `mapper` +  // region. +  if (getInputs().size() + 1 != blockArgs.size())      return emitOpError() << "expects number of operands to match the arity of "                              "mapper, but got: " -                         << getInputs().size() << " and " << blockArgs.size(); +                         << getInputs().size() + 1 << " and " +                         << blockArgs.size();    // The parameters of mapper should all match the element type of inputs.    for (const auto &[bbArgType, inputArg] : diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 8b89244..b09112b 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -4499,7 +4499,7 @@ DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(              maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);              return true;            }) -          .Default([&](Operation *op) { return false; }); +          .Default(false);    if (!supported) {      DiagnosedSilenceableFailure diag = diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp index 3e31393..75bb175 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -31,10 +31,8 @@ using namespace mlir;  using namespace mlir::linalg;  static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp) { -  // Bailout if `linalgOp` is already a generic or a linalg.map. We cannot -  // trivially generalize a `linalg.map`, as it does not use the output as -  // region arguments in the block. -  if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp)) +  // Bailout if `linalgOp` is already a generic. +  if (isa<GenericOp>(linalgOp))      return failure();    // Check if the operation has exactly one region.    if (linalgOp->getNumRegions() != 1) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index f05ffa8..6519c4f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -322,7 +322,7 @@ promoteSubViews(ImplicitLocOpBuilder &b,                  tmp = arith::ConstantOp::create(b, IntegerAttr::get(et, 0));                return complex::CreateOp::create(b, t, tmp, tmp);              }) -            .Default([](auto) { return Value(); }); +            .Default(nullptr);      if (!fillVal)        return failure();      linalg::FillOp::create(b, fillVal, promotionInfo->fullLocalView); diff --git a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp index 27ccf3c..6becc1f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp @@ -89,7 +89,7 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,                  ValueRange{input, collapsedKernel, iZp, kZp},                  ValueRange{collapsedInit}, stride, dilation);            }) -          .Default([](Operation *op) { return nullptr; }); +          .Default(nullptr);    if (!newConv)      return failure();    for (auto attr : preservedAttrs) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 9d62491..cb6199f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -656,7 +656,7 @@ mlir::linalg::getCombinerOpKind(Operation *combinerOp) {            [&](auto op) { return CombiningKind::MUL; })        .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })        .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; }) -      .Default([&](auto op) { return std::nullopt; }); +      .Default(std::nullopt);  }  /// Check whether `outputOperand` is a reduction with a single combiner @@ -3911,21 +3911,21 @@ struct Conv1DGenerator      Value lhs = vector::TransferReadOp::create(          rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},          /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType)); -    auto maybeMaskedLhs = maybeMaskXferOp( +    auto *maybeMaskedLhs = maybeMaskXferOp(          lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());      // Read rhs slice of size {kw, c} @ [0, 0].      Value rhs = vector::TransferReadOp::create(          rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},          /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType)); -    auto maybeMaskedRhs = maybeMaskXferOp( +    auto *maybeMaskedRhs = maybeMaskXferOp(          rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());      // Read res slice of size {n, w, c} @ [0, 0, 0].      Value res = vector::TransferReadOp::create(          rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},          /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType)); -    auto maybeMaskedRes = maybeMaskXferOp( +    auto *maybeMaskedRes = maybeMaskXferOp(          resType.getShape(), resType.getScalableDims(), res.getDefiningOp());      //===------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index 1208fdd..e685089 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -104,7 +104,7 @@ static Value getTargetMemref(Operation *op) {                       vector::MaskedStoreOp, vector::TransferReadOp,                       vector::TransferWriteOp>(            [](auto op) { return op.getBase(); }) -      .Default([](auto) { return Value{}; }); +      .Default(nullptr);  }  template <typename T> diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp index 4ebd90d..d380c46 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp @@ -55,7 +55,7 @@ static bool isShapePreserving(ForOp forOp, int64_t arg) {                               ? forOp.getInitArgs()[opResult.getResultNumber()]                               : Value();                  }) -                .Default([&](auto op) { return Value(); }); +                .Default(nullptr);    }    return false;  } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 0c8114d..938952e 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -346,7 +346,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {        llvm::TypeSwitch<Type, Type>(getType())            .Case<spirv::CooperativeMatrixType>(                [](auto coopType) { return coopType.getElementType(); }) -          .Default([](Type) { return nullptr; }); +          .Default(nullptr);    // Case 1. -- matrices.    if (coopElementType) { @@ -1708,7 +1708,7 @@ LogicalResult spirv::MatrixTimesScalarOp::verify() {        llvm::TypeSwitch<Type, Type>(getMatrix().getType())            .Case<spirv::CooperativeMatrixType, spirv::MatrixType>(                [](auto matrixType) { return matrixType.getElementType(); }) -          .Default([](Type) { return nullptr; }); +          .Default(nullptr);    assert(elementType && "Unhandled type"); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index f895807..d1e275d 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -731,7 +731,7 @@ std::optional<int64_t> SPIRVType::getSizeInBytes() {            return *elementSize * type.getNumElements();          return std::nullopt;        }) -      .Default(std::optional<int64_t>()); +      .Default(std::nullopt);  }  //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 88e1ab6..cb9b7f6 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -1467,7 +1467,7 @@ mlir::spirv::getNativeVectorShape(Operation *op) {    return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)        .Case<vector::ReductionOp, vector::TransposeOp>(            [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); }) -      .Default([](Operation *) { return std::nullopt; }); +      .Default(std::nullopt);  }  LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) { diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index ac72002..110bfdc 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -41,10 +41,6 @@  using namespace mlir;  using namespace mlir::tensor; -using llvm::divideCeilSigned; -using llvm::divideFloorSigned; -using llvm::mod; -  /// Materialize a single constant operation from a given attribute value with  /// the desired resultant type.  Operation *TensorDialect::materializeConstant(OpBuilder &builder, diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index bce964e..c607ece 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -579,6 +579,7 @@ static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,        linalg::MapOp::create(rewriter, loc, tensorType, /*inputs=*/ValueRange(),                              /*init=*/tensorDestination);    Block &linalgBody = linalgOp.getMapper().emplaceBlock(); +  linalgBody.addArgument(tensorType.getElementType(), loc);    // Create linalg::IndexOps.    rewriter.setInsertionPointToStart(&linalgBody); @@ -1068,6 +1069,7 @@ struct SplatOpInterface                                            /*inputs=*/ValueRange(),                                            /*init=*/*tensorAlloc);      Block &linalgBody = linalgOp.getMapper().emplaceBlock(); +    linalgBody.addArgument(tensorType.getElementType(), loc);      // Create linalg::IndexOps.      rewriter.setInsertionPointToStart(&linalgBody); diff --git a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp index 69e649d..bc4f5a5 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp @@ -189,7 +189,7 @@ struct PadOpToConstant final : public OpRewritePattern<PadOp> {                return constantFoldPadOp<llvm::APInt>(                    rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad);              }) -            .Default(Value()); +            .Default(nullptr);      if (!newOp)        return rewriter.notifyMatchFailure(padTensorOp, | 
