//===- mlir-linalg-ods-yaml-gen.cpp - Linalg ODS generation from yaml ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an ODS (and C++) generator from a YAML form
// derived from the mathematical expression of linalg named ops. Typically a
// math oriented DSL will be used to export the essential representation to
// this form, and maintaining the SOT at the math level (versus recreating it
// in MLIR) is deemed to have systemic value.
//
//===----------------------------------------------------------------------===//

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Support/YAMLTraits.h"
#include <optional>

using namespace mlir;

using llvm::yaml::Input;
using llvm::yaml::MappingTraits;
using llvm::yaml::ScalarEnumerationTraits;
using llvm::yaml::ScalarTraits;

#define DEBUG_TYPE "linalg-ods-gen"

//===----------------------------------------------------------------------===//
// Mapping structs (correspond to data types in the YAML description).
// TODO: Since this is a schema/part of the contract, it should be moved to
// a real header.
//===----------------------------------------------------------------------===// namespace { struct LinalgYAMLContext { MLIRContext *mlirContext; }; struct LinalgOpMetadata { std::string name; std::string cppClassName; std::optional doc; SmallVector implements; SmallVector defines; }; struct SerializedAffineMap { AffineMapAttr affineMapAttr; AffineMap affineMap() { return affineMapAttr.getValue(); } }; enum class LinalgOperandDefKind { InputTensor, Scalar, OutputTensor, IndexAttr, UnaryFnAttr, BinaryFnAttr, TernaryFnAttr, TypeFnAttr }; struct LinalgOperandDef { std::string name; LinalgOperandDefKind kind; std::optional typeVar; std::optional shapeMap; std::optional indexAttrMap; std::optional> defaultIndices; std::optional defaultFn; }; enum class LinalgIteratorTypeDef { parallel, reduction, }; struct LinalgIndexingMapsConfig { std::optional> staticIndexingMaps; }; struct ScalarExpression; enum class ScalarFnKind { Unary, Binary, Ternary, Type }; struct ScalarFn { ScalarFnKind kind; std::optional fnName; std::optional attrName; std::optional typeVar; // NOTE: This must be of arity 1, but to break the self-referential cycle, // we use a heap allocated vector. std::vector operands; }; struct ScalarExpression { std::optional arg; std::optional constant; std::optional index; std::optional scalarFn; }; struct ScalarAssign { std::string arg; ScalarExpression value; }; struct LinalgStructuredOpConfig { SmallVector args; LinalgIndexingMapsConfig indexingMaps; SmallVector iteratorTypes; std::vector assignments; }; struct LinalgOpConfig { std::optional metadata; std::optional structuredOp; }; } // namespace //===----------------------------------------------------------------------===// // Mapping traits. 
//===----------------------------------------------------------------------===// LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgOperandDef) LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap) LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef) LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign) LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarExpression) LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(LinalgOpConfig) namespace llvm { namespace yaml { /// Top-level type containing op metadata and one of a concrete op type. /// Currently, the only defined op type is `structured_op` (maps to /// `LinalgStructuredOpConfig`). template <> struct MappingTraits { static void mapping(IO &io, LinalgOpConfig &info) { io.mapOptional("metadata", info.metadata); io.mapOptional("structured_op", info.structuredOp); } }; /// A structured op models (at most) a single contraction by modeling /// - A list of named arguments (`LinalgOperandDef`), which can be inputs, /// outputs, or index attributes. /// - List of indexing maps (see `LinalgIndexingMaps`). /// - Iterator types (see `LinalgIteratorTypeDef`). /// - List of scalar level assignment (see `ScalarAssign`). template <> struct MappingTraits { static void mapping(IO &io, LinalgStructuredOpConfig &info) { io.mapRequired("args", info.args); io.mapRequired("indexing_maps", info.indexingMaps); io.mapRequired("iterator_types", info.iteratorTypes); io.mapRequired("assignments", info.assignments); } }; /// Maps a named tensor, scalar or attribute argument to an operation, /// consisting of: /// - `name`: Must be unique within the operation. /// - `usage`: How the argument is used (input, output, attribute, etc). /// - `type_var`: The symbolic type variable that binds to the element or self /// type of the tensor or scalar argument, respectively. /// - `shape_map`: An optional AffineMap from all op symbols to the shape of /// the argument. Only tensor arguments have a `shape_map`. Each shape must /// be normalized over the same list of symbols and have no dimension /// inputs. 
/// - `index_attr_map`: An optional AffineMap from all op symbols to the /// index attribute symbols. During op creation these symbols are replaced /// by the corresponding `name` index attribue values. Only index attribute /// arguments have an `index_attr_map`. /// - `default_indices`: An optional default initialization for index /// attribute arguments. /// - `default_fn`: An optional default initialization for function attribute /// arguments. template <> struct MappingTraits { static void mapping(IO &io, LinalgOperandDef &info) { io.mapRequired("name", info.name); io.mapRequired("kind", info.kind); io.mapOptional("type_var", info.typeVar); io.mapOptional("shape_map", info.shapeMap); io.mapOptional("index_attr_map", info.indexAttrMap); io.mapOptional("default_indices", info.defaultIndices); io.mapOptional("default_fn", info.defaultFn); } }; /// Usage enum for a named argument. template <> struct ScalarEnumerationTraits { static void enumeration(IO &io, LinalgOperandDefKind &value) { io.enumCase(value, "input_tensor", LinalgOperandDefKind::InputTensor); io.enumCase(value, "scalar", LinalgOperandDefKind::Scalar); io.enumCase(value, "output_tensor", LinalgOperandDefKind::OutputTensor); io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr); io.enumCase(value, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr); io.enumCase(value, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr); io.enumCase(value, "ternary_fn_attr", LinalgOperandDefKind::TernaryFnAttr); io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr); } }; /// Iterator type enum. template <> struct ScalarEnumerationTraits { static void enumeration(IO &io, LinalgIteratorTypeDef &value) { io.enumCase(value, "parallel", LinalgIteratorTypeDef::parallel); io.enumCase(value, "reduction", LinalgIteratorTypeDef::reduction); } }; /// Metadata about the op (name, C++ name, and documentation). 
template <> struct MappingTraits { static void mapping(IO &io, LinalgOpMetadata &info) { io.mapRequired("name", info.name); io.mapRequired("cpp_class_name", info.cppClassName); io.mapOptional("doc", info.doc); io.mapOptional("implements", info.implements); io.mapOptional("defines", info.defines); } }; /// How the ops indexing maps are produced. Must be one of: /// - static_indexing_maps: A static list of AffineMaps, possibly with /// some symbols that bind to attributes of the op. Each indexing map must /// be normalized over the same list of dimensions, and its symbols must /// match the symbols for argument shapes. template <> struct MappingTraits { static void mapping(IO &io, LinalgIndexingMapsConfig &info) { io.mapOptional("static_indexing_maps", info.staticIndexingMaps); } }; /// Models an assignment to a named output. /// - The `arg` name must match a named output. /// - The `value` is a scalar expression for computing the value to /// assign (see `ScalarExpression`). template <> struct MappingTraits { static void mapping(IO &io, ScalarAssign &info) { io.mapRequired("arg", info.arg); io.mapRequired("value", info.value); } }; /// A scalar expression (RHS of an assignment). Must be one of: /// - `scalar_arg`: An operation argument. /// - `scalar_const`: A constant definition. /// - `scalar_index`: An iteration index. /// - `scalar_fn`: A named function (see `ScalarFn`). template <> struct MappingTraits { static void mapping(IO &io, ScalarExpression &info) { io.mapOptional("scalar_arg", info.arg); io.mapOptional("scalar_const", info.constant); io.mapOptional("scalar_index", info.index); io.mapOptional("scalar_fn", info.scalarFn); } }; /// Scalar function kind enum. 
template <> struct ScalarEnumerationTraits { static void enumeration(IO &io, ScalarFnKind &value) { io.enumCase(value, "unary", ScalarFnKind::Unary); io.enumCase(value, "binary", ScalarFnKind::Binary); io.enumCase(value, "ternary", ScalarFnKind::Ternary); io.enumCase(value, "type", ScalarFnKind::Type); } }; /// A scalar expression that evaluates a named function. /// Functions are generally "math" level and type polymorphic. Builtin /// functions include: /// - `add(lhs, rhs)` /// - `mul(lhs, rhs)` template <> struct MappingTraits { static void mapping(IO &io, ScalarFn &info) { io.mapRequired("kind", info.kind); io.mapOptional("fn_name", info.fnName); io.mapOptional("attr_name", info.attrName); io.mapOptional("type_var", info.typeVar); io.mapRequired("operands", info.operands); } }; /// Helper mapping which accesses an AffineMapAttr as a serialized string of /// the same. template <> struct ScalarTraits { static void output(const SerializedAffineMap &value, void *rawYamlContext, raw_ostream &out) { assert(value.affineMapAttr); value.affineMapAttr.print(out); } static StringRef input(StringRef scalar, void *rawYamlContext, SerializedAffineMap &value) { assert(rawYamlContext); auto *yamlContext = static_cast(rawYamlContext); if (auto attr = dyn_cast_or_null( mlir::parseAttribute(scalar, yamlContext->mlirContext))) value.affineMapAttr = attr; else if (!value.affineMapAttr || !isa(value.affineMapAttr)) return "could not parse as an affine map attribute"; return StringRef(); } static QuotingType mustQuote(StringRef) { return QuotingType::None; } }; } // namespace yaml } // namespace llvm namespace { //===----------------------------------------------------------------------===// // Generation utilities //===----------------------------------------------------------------------===// class GenerationContext { public: GenerationContext(MLIRContext *context, raw_ostream *odsOut, raw_ostream *defnOut) : context(context), loc(UnknownLoc::get(context)), odsOut(odsOut), 
defnOut(defnOut) {} MLIRContext *getContext() { return context; } void setLoc(Location loc) { this->loc = loc; } Location getLoc() { return loc; } bool shouldGenerateOds() { return odsOut; } bool shouldGenerateDefns() { return defnOut; } raw_ostream &odss() { assert(odsOut && "ODS stream not defined"); return *odsOut; } raw_ostream &defns() { assert(defnOut && "Definition stream not defined"); return *defnOut; } private: MLIRContext *context; Location loc; raw_ostream *odsOut; raw_ostream *defnOut; }; } // namespace static std::string generateCppExpression(SerializedAffineMap self, StringRef contextName) { std::string printedStr; llvm::raw_string_ostream printedSs(printedStr); self.affineMapAttr.print(printedSs); printedSs.flush(); static const char exprFormat[] = R"FMT(llvm::cast(mlir::parseAttribute("{0}", {1})).getValue())FMT"; return llvm::formatv(exprFormat, printedStr, contextName); } template static std::string interleaveToString(Container &container, StringRef separator) { std::string result; llvm::raw_string_ostream ss(result); llvm::interleave(container, ss, separator); ss.flush(); return result; } static std::optional findTensorDefArgIndex(StringRef name, SmallVectorImpl &args) { for (const auto &it : llvm::enumerate(args)) { if (it.value().name == name) return it.index(); } return std::nullopt; } // Try to map the TypeVar to a predefined or an argument type. static std::optional findTypeValue(StringRef typeVar, SmallVectorImpl &args) { // Handle all predefined types. if (typeVar == "I32") return std::string("helper.getIntegerType(32)"); if (typeVar == "I64") return std::string("helper.getIntegerType(64)"); if (typeVar == "F32") return std::string("helper.getFloat32Type()"); if (typeVar == "F64") return std::string("helper.getFloat64Type()"); // Search all argument types. 
for (const auto &it : llvm::enumerate(args)) { if (it.value().kind != LinalgOperandDefKind::InputTensor && it.value().kind != LinalgOperandDefKind::Scalar && it.value().kind != LinalgOperandDefKind::OutputTensor) continue; if (*it.value().typeVar == typeVar) return llvm::formatv("block.getArgument({0}).getType()", it.index()) .str(); } return std::nullopt; } static ScalarAssign *findAssignment(StringRef name, std::vector &assignments) { for (auto &assign : assignments) { if (assign.arg == name) return &assign; } return nullptr; } // Return true if the operand is a function attribute. static bool isFunctionAttribute(LinalgOperandDefKind kind) { return kind == LinalgOperandDefKind::UnaryFnAttr || kind == LinalgOperandDefKind::BinaryFnAttr || kind == LinalgOperandDefKind::TernaryFnAttr || kind == LinalgOperandDefKind::TypeFnAttr; } // Return true if the operand is an attribute. static bool isAttribute(LinalgOperandDefKind kind) { return kind == LinalgOperandDefKind::IndexAttr || isFunctionAttribute(kind); } // Get the enum name for the given operand kind. std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) { switch (kind) { case LinalgOperandDefKind::UnaryFnAttr: return std::string("UnaryFn"); case LinalgOperandDefKind::BinaryFnAttr: return std::string("BinaryFn"); case LinalgOperandDefKind::TernaryFnAttr: return std::string("TernaryFn"); case LinalgOperandDefKind::TypeFnAttr: return std::string("TypeFn"); default: break; } llvm_unreachable("unsupported function attribute kind"); } // Get the enum name for the given function kind. 
std::string convertFunctionKindToEnumName(ScalarFnKind kind) {
  // Names the generated enum family (e.g. `UnaryFn`) that a scalar function
  // of the given kind belongs to.
  const char *enumName = nullptr;
  switch (kind) {
  case ScalarFnKind::Unary:
    enumName = "UnaryFn";
    break;
  case ScalarFnKind::Binary:
    enumName = "BinaryFn";
    break;
  case ScalarFnKind::Ternary:
    enumName = "TernaryFn";
    break;
  case ScalarFnKind::Type:
    enumName = "TypeFn";
    break;
  }
  if (!enumName)
    llvm_unreachable("unsupported function kind");
  return std::string(enumName);
}

//===----------------------------------------------------------------------===//
// Templates
//===----------------------------------------------------------------------===//

// A single line banner format. Parameters:
// {0}: Single line comment
static const char bannerFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
)FMT";

//===----------------------------------------------------------------------===//
// Named generic op generation.
// These ops map at most a single contraction that complies with the limitations
// of a linalg.generic.
//===----------------------------------------------------------------------===//

// Template for Linalg named ops' ODS definitions.
// Parameters:
// {0}: ODS/C++ op name
// {1}: assembly op mnemonic
// {2}: op interface list
// {3}: documentation (summary + description)
// {4}: op attribute list
// {5}: builder methods taking standalone attribute parameters
// {6}: additional `let <name> = 1;` definitions
// {7}: additional methods for attributes used by indexing maps
static const char structuredOpOdsHeaderFormat[] = R"FMT(
//===----------------------------------------------------------------------===//
// Op definition for {0}
//===----------------------------------------------------------------------===//

def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
  /*extraInterfaces=*/[{2}])> {
    {3}
    let arguments = (ins
      Variadic<AnyType>:$inputs,
      Variadic<AnyShaped>:$outputs{4}
    );
    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
    let regions = (region AnyRegion:$region);

    let skipDefaultBuilders = 1;
    let builders = [
      OpBuilder<
      (ins "ValueRange":$inputs, "ValueRange":$outputs,
            CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
      [{{
        buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
          attributes, {0}::getRegionBuilder());
      }]>,
      OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
            "ValueRange":$outputs,
            CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
      [{{
        buildStructuredOp($_builder, $_state, resultTensorTypes,
          inputs, outputs, attributes, {0}::getRegionBuilder());
      }]>,
      OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
            CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
      [{{
        $_state.addOperands(operands);
        $_state.addAttributes(attributes);
        $_state.addTypes(resultTensorTypes);
        (void)$_state.addRegion();
      }]>
      {5}
    ];
    let hasCustomAssemblyFormat = 1;
    let hasFolder = 1;
    {6}

    let extraClassDeclaration = structuredOpsBaseDecls # [{{
      // Auto-generated.
      SmallVector<utils::IteratorType> getIteratorTypesArray();
      ArrayAttr getIndexingMaps();
      static void regionBuilder(ImplicitLocOpBuilder &b,
                                Block &block, ArrayRef<NamedAttribute> attrs);
      static std::function<void(ImplicitLocOpBuilder &,
                                Block &, ArrayRef<NamedAttribute>)>
      getRegionBuilder() {{
        return regionBuilder;
      }

      ::mlir::MutableOperandRange getDpsInitsMutable() {{
        return getOutputsMutable();
      }

      // Generic methods.
      static unsigned getNumRegionArgs();
      std::string getLibraryCallName();
      {7}
    }];
}
)FMT";

// Builder method taking attribute parameters. Parameters:
// {0}: Class name
// {1}: Comma interleaved attribute parameters
// {2}: Attribute initialization
static const char structuredOpBuilderFormat[] = R"FMT(
  , OpBuilder<
  (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
       "ValueRange":$outputs, {1},
       CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
  [{{
    {2}
    buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
      attributes, {0}::getRegionBuilder());
  }]>
)FMT";

// The getIteratorTypesArray() method for structured ops. Parameters:
// {0}: Class name
// {1}: Comma interleaved iterator type names.
static const char structuredOpIteratorTypesFormat[] =
    R"FMT(
SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
  return SmallVector<utils::IteratorType>{{ {1} };
}
)FMT";

// The getIteratorTypesArray() method for rank polymorphic structured ops.
// Parameters:
// {0}: Class name
static const char rankPolyStructuredOpIteratorTypesFormat[] =
    R"FMT(
SmallVector<utils::IteratorType> {0}::getIteratorTypesArray() {{
  int64_t rank = getRank(getDpsInitOperand(0));
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}
)FMT";

// The indexing_maps() method for structured ops. Parameters:
// {0}: Class name
// {1}: Comma-separated list of dimension variable names (currently not
//      referenced by the template).
// {2}: Statements
static const char structuredOpIndexingMapsFormat[] = R"FMT(
ArrayAttr {0}::getIndexingMaps() {{
  static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
  if (cached)
    return cached;

  MLIRContext *context = getContext();
  auto symbolBindings = getSymbolBindings(*this);
  SmallVector<AffineMap> maps;
  {2}
  cached = Builder(context).getAffineMapArrayAttr(maps);
  getOperation()->setAttr(memoizeAttr, cached);
  return cached;
}
)FMT";

// The indexing_maps() method for rank polymorphic structured ops. Parameters:
// {0}: Class name
static const char rankPolyStructuredOpIndexingMapsFormat[] = R"FMT(
ArrayAttr {0}::getIndexingMaps() {{
  MLIRContext *context = getContext();
  AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
  AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
    getNumParallelLoops(), context);
  SmallVector<AffineMap> indexingMaps;
  for (OpOperand &opOperand : getOperation()->getOpOperands())
    indexingMaps.push_back(getRank(&opOperand) == 0 ? scalarMap : tensorMap);
  return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
}
)FMT";

// Implementations of fold and getEffects.
// Parameters:
// {0}: Class name
static const char structuredOpFoldersFormat[] = R"FMT(
LogicalResult {0}::fold(FoldAdaptor,
                        SmallVectorImpl<OpFoldResult> &) {{
  return memref::foldMemRefCast(*this);
}
void {0}::getEffects(SmallVectorImpl<
    SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
      if (hasPureTensorSemantics()) return;
      getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
)FMT";

// Implementation of parse/print.
// Parameters: // {0}: Class name static const char structuredOpParserFormat[] = R"FMT( ParseResult {0}::parse(OpAsmParser &parser, OperationState &result) {{ return ::parseNamedStructuredOp(parser, result, {0}::getNumRegionArgs(), {0}::getRegionBuilder()); } void {0}::print(OpAsmPrinter &p) {{ ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs()); } )FMT"; static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig, GenerationContext &genContext) { if (!genContext.shouldGenerateOds()) return success(); raw_ostream &os = genContext.odss(); std::string interfaceNameList; std::string attrList; std::string attrMethods; std::string attrBuilder; std::string doc; if (opConfig.metadata->doc) { static const char structuredOpDocFmt[] = R"FMT( let summary = [{{{0}}]; let description = [{{{1}}]; )FMT"; StringRef summary, description; std::tie(summary, description) = StringRef(*opConfig.metadata->doc).trim().split("\n\n"); doc = llvm::formatv(structuredOpDocFmt, summary.trim(), description.trim()); } interfaceNameList = interleaveToString(opConfig.metadata->implements, ", "); std::string definitionList; for (const std::string &definition : opConfig.metadata->defines) { static const char definitionFmt[] = "let {0} = 1;\n"; definitionList.append(llvm::formatv(definitionFmt, definition)); } if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { return isAttribute(arg.kind); })) { SmallVector attrDefs; SmallVector attrParams; SmallVector attrStmts; for (LinalgOperandDef &arg : opConfig.structuredOp->args) { static const char paramFmt[] = "\"Attribute\":${0}"; static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});"; // Add the type conversion attributes to the op definition and builders. 
if (isFunctionAttribute(arg.kind)) { assert(arg.defaultFn); std::string enumName = convertOperandKindToEnumName(arg.kind); static const char typeFmt[] = "{0}::{1}"; static const char defFmt[] = "DefaultValuedOptionalAttr<{0}, \"{1}\">:${2}"; attrDefs.push_back(llvm::formatv( defFmt, llvm::formatv("{0}Attr", enumName), llvm::formatv(typeFmt, enumName, arg.defaultFn), arg.name)); attrParams.push_back(llvm::formatv(paramFmt, arg.name)); attrStmts.push_back(llvm::formatv(stmtFmt, arg.name)); } // Add the index attributes to the op definition and builders. if (arg.kind == LinalgOperandDefKind::IndexAttr) { assert(arg.indexAttrMap.has_value()); assert(arg.defaultIndices.has_value()); size_t size = arg.indexAttrMap->affineMap().getNumResults(); assert(arg.defaultIndices->size() == size); static const char typeFmt[] = "RankedI64ElementsAttr<[{0}]>"; static const char defFmt[] = "DefaultValuedOptionalAttr<{0}, \"{ {1} }\">:${2}"; std::string defaultVals; llvm::raw_string_ostream ss(defaultVals); llvm::interleave( *arg.defaultIndices, ss, [&](int64_t val) { ss << "static_cast(" << val << ")"; }, ", "); attrDefs.push_back(llvm::formatv(defFmt, llvm::formatv(typeFmt, size), ss.str(), arg.name)); attrParams.push_back(llvm::formatv(paramFmt, arg.name)); attrStmts.push_back(llvm::formatv(stmtFmt, arg.name)); } } if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { return arg.kind == LinalgOperandDefKind::IndexAttr; })) { attrMethods = R"( bool hasDynamicIndexingMaps(); LogicalResult verifyIndexingMapRequiredAttributes(); )"; } attrList = ",\n" + llvm::join(attrDefs, ",\n"); attrBuilder = llvm::formatv( structuredOpBuilderFormat, opConfig.metadata->cppClassName, llvm::join(attrParams, ", "), llvm::join(attrStmts, "\n")); } os << llvm::formatv(structuredOpOdsHeaderFormat, opConfig.metadata->cppClassName, opConfig.metadata->name, interfaceNameList, doc, attrList, attrBuilder, definitionList, attrMethods); return success(); } static LogicalResult 
generateNamedGenericOpDefns(LinalgOpConfig &opConfig, GenerationContext &genContext) { if (!genContext.shouldGenerateDefns()) return success(); raw_ostream &os = genContext.defns(); StringRef className = opConfig.metadata->cppClassName; // Implementation banner. std::string bannerComment = llvm::formatv("Implementation of {0}", className); os << llvm::formatv(bannerFormat, bannerComment); // Compute the number of scalar and tensor arguments. int64_t numOfArgs = llvm::count_if(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { return arg.kind == LinalgOperandDefKind::InputTensor || arg.kind == LinalgOperandDefKind::Scalar || arg.kind == LinalgOperandDefKind::OutputTensor; }); // An operation that accesses only scalars and scalar/rank zero tensors is // rank polymorhpic. We implement rank polymorphism by generating different // indexing maps and iterators that match the rank of the first output tensor. // An operation is rank polymorphic if the iteration domain has rank zero. bool isRankPolymorphic = opConfig.structuredOp->iteratorTypes.empty(); // Generate the iterator_types() method. if (!isRankPolymorphic) { std::string iteratorsStr; llvm::raw_string_ostream ss(iteratorsStr); llvm::interleaveComma(opConfig.structuredOp->iteratorTypes, ss, [&](LinalgIteratorTypeDef it) { switch (it) { case LinalgIteratorTypeDef::parallel: ss << "utils::IteratorType::parallel"; break; case LinalgIteratorTypeDef::reduction: ss << "utils::IteratorType::reduction"; break; } }); ss.flush(); os << llvm::formatv(structuredOpIteratorTypesFormat, className, iteratorsStr); } else { os << llvm::formatv(rankPolyStructuredOpIteratorTypesFormat, className); } // Generating the getIndexingMaps() method. if (auto &staticMaps = opConfig.structuredOp->indexingMaps.staticIndexingMaps) { if (staticMaps->empty()) return emitError(genContext.getLoc()) << "op has no indexing maps"; if (!isRankPolymorphic) { AffineMap firstMap = staticMaps->front().affineMap(); // Symbol bindings. 
{ // For each symbol, generate a declaration for it, either with an // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from // an attribute). // TODO: Possibly lift into a top-level method. static const char structuredOpSymbolBindingsFormat[] = R"FMT( static SmallVector getSymbolBindings({0} self) { MLIRContext *context = self.getContext(); SmallVector exprs; {1} return exprs; } )FMT"; unsigned symbolCount = firstMap.getNumSymbols(); SmallVector symbolBindings; for (unsigned i = 0; i < symbolCount; ++i) { symbolBindings.push_back(llvm::formatv( " exprs.push_back(getAffineSymbolExpr({0}, context));", i)); } // Access an index attribute. Parameters: // {0}: Attribute name // {1}: Symbol position // {2}: Attribute index static const char structuredOpAccessAttrFormat[] = R"FMT( int64_t cst{1} = self.get{0}().getValues()[{2}]; exprs.push_back(getAffineConstantExpr(cst{1}, context)); )FMT"; // Update all symbol bindings mapped to an attribute. for (LinalgOperandDef &arg : opConfig.structuredOp->args) { if (arg.kind != LinalgOperandDefKind::IndexAttr) continue; assert(arg.indexAttrMap); for (auto [idx, result] : llvm::enumerate(arg.indexAttrMap->affineMap().getResults())) { if (auto symbol = dyn_cast(result)) { std::string argName = arg.name; argName[0] = toupper(argName[0]); symbolBindings[symbol.getPosition()] = llvm::formatv(structuredOpAccessAttrFormat, argName, symbol.getPosition(), idx); } } } std::string symbolBindingsStr; llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr); llvm::interleave(symbolBindings, symbolBindingsSs, "\n"); symbolBindingsSs.flush(); os << llvm::formatv(structuredOpSymbolBindingsFormat, className, symbolBindingsStr); } // Indexing maps. { unsigned dimCount = firstMap.getNumDims(); // Generate a comma-separated list of dim identifiers to be passed to // bindDims, ensuring tht AffineExpr identifiers are bound in the right // order to the proper AffineDimExpr. // This results in vars in scope like: d0, d1, d2... 
SmallVector dimIndices; for (unsigned i = 0; i < dimCount; ++i) dimIndices.push_back(i); std::string dimIdentsStr; llvm::raw_string_ostream dimIdentsSs(dimIdentsStr); llvm::interleaveComma(dimIndices, dimIdentsSs, [&](unsigned i) { dimIdentsSs << "d" << i; }); dimIdentsSs.flush(); // Statements to add and simplify each affine map. SmallVector stmts; for (auto &indexingMap : *staticMaps) { // TODO: Assert that dim and symbol count match the first. stmts.push_back( llvm::formatv("maps.push_back({0});", generateCppExpression(indexingMap, "context"))); stmts.push_back(llvm::formatv( "maps.back() = " "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, " "symbolBindings, {0}, 0));", dimCount)); } // TODO: This needs to be memoized and/or converted to non-parser based // C++ codegen prior to real use. os << llvm::formatv(structuredOpIndexingMapsFormat, className, dimIdentsStr, interleaveToString(stmts, "\n ")); } } else { os << llvm::formatv(rankPolyStructuredOpIndexingMapsFormat, className); } } else { return emitError(genContext.getLoc()) << "generating code for non static indexing maps not currently " "supported"; } // getNumRegionArgs() { // Generates a getNumRegionArgs() method. Parameters: // {0}: Class name // {1}: Number of region args static const char structuredOpGetNumRegionArgsFormat[] = R"FMT( unsigned {0}::getNumRegionArgs() {{ return {1}; } )FMT"; os << llvm::formatv(structuredOpGetNumRegionArgsFormat, className, numOfArgs); } // getLibraryCallName() { // Generates a getLibraryCallName method. 
Parameters: // {0}: Class name static const char structuredOpGetLibraryCallFormat[] = R"FMT( std::string {0}::getLibraryCallName() {{ return generateLibraryCallName(getOperation()); } )FMT"; os << llvm::formatv(structuredOpGetLibraryCallFormat, className); } // hasDynamicIndexingMaps() and verifyIndexingMapRequiredAttributes() if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { return arg.kind == LinalgOperandDefKind::IndexAttr; })) { std::vector attrVerifications; for (LinalgOperandDef &arg : opConfig.structuredOp->args) { if (arg.kind != LinalgOperandDefKind::IndexAttr) continue; assert(arg.indexAttrMap); // Verify index attribute. Paramters: // {0}: Attribute name // {1}: Attribute size static const char attrFmt[] = R"FMT( if (auto attr = op->getAttrOfType("{0}")) {{ if (!attr.getType().getElementType().isInteger(64)) return op->emitError("incorrect element type for index attribute '{0}'"); if (attr.getType().getShape() != ArrayRef{{ {1} }) return op->emitError("incorrect shape for index attribute '{0}'"); } )FMT"; attrVerifications.push_back(llvm::formatv( attrFmt, arg.name, arg.indexAttrMap->affineMap().getNumResults())); } // Generates the verifyIndexingMapRequiredAttributes method. Parameters: // {0}: Class name // {1}: Attribute verification static const char structuredOpVerifyIndexingMapRequiredAttributes[] = R"FMT( bool {0}::hasDynamicIndexingMaps() {{ return true; } LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{ Operation *op = getOperation(); {1} return success(); } )FMT"; os << llvm::formatv(structuredOpVerifyIndexingMapRequiredAttributes, className, llvm::join(attrVerifications, "\n")); } // regionBuilder() { // Generates a regionBuilder method. Parameters. 
// {0}: Class name // {1}: Number of args // {2}: Attributes // {3}: Statements static const char structuredOpRegionBuilderFormat[] = R"FMT( void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block, ArrayRef attrs) {{ assert({1} > 0 && block.getNumArguments() == {1} && "{0} regionBuilder expects {1} (>=0) args"); RegionBuilderHelper helper(b, block); SmallVector yields; {2} {3} helper.yieldOutputs(yields); } )FMT"; auto &args = opConfig.structuredOp->args; auto &assignments = opConfig.structuredOp->assignments; size_t generatedAssignmentCount = 0; int localCounter = 0; SmallVector attrs; SmallVector stmts; for (LinalgOperandDef &arg : args) { if (!isFunctionAttribute(arg.kind)) continue; // Obtain the type function attribute values. Parameters. // {0}: enum name // {1}: attribute name // {2}: default type function name static const char attrDef[] = R"FMT( {0} {1}Val = {0}::{2}; auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{ return attr.getName() == "{1}"; }); if ({1}Iter != attrs.end()) {{ if (auto attr = llvm::dyn_cast<{0}Attr>({1}Iter->getValue())) {1}Val = attr.getValue(); } )FMT"; std::string enumName = convertOperandKindToEnumName(arg.kind); attrs.push_back( llvm::formatv(attrDef, enumName, arg.name, arg.defaultFn)); } for (LinalgOperandDef &arg : args) { if (arg.kind != LinalgOperandDefKind::OutputTensor) continue; // Find the assignment that correlates with the argument. ScalarAssign *assignment = findAssignment(arg.name, assignments); if (!assignment) return emitError(genContext.getLoc()) << "no assignment found for output argument " << arg.name; ++generatedAssignmentCount; // Recursively generate the expression. std::function(ScalarExpression &)> generateExpression = [&](ScalarExpression &expression) -> std::optional { if (expression.arg) { // Argument reference. 
std::optional argIndex = findTensorDefArgIndex(*expression.arg, args); if (!argIndex) { emitError(genContext.getLoc()) << "scalar argument not defined on the op: " << *expression.arg; return std::nullopt; } return std::string( llvm::formatv("block.getArgument({0})", *argIndex)); } if (expression.constant) { std::string cppIdent = llvm::formatv("value{0}", ++localCounter); stmts.push_back( llvm::formatv(R"FMT(Value {0} = helper.constant("{1}");)FMT", cppIdent, expression.constant)); return cppIdent; } if (expression.index) { // Access an iteration index. std::string cppIdent = llvm::formatv("value{0}", ++localCounter); stmts.push_back(llvm::formatv("Value {0} = helper.index({1});", cppIdent, *expression.index)); return cppIdent; } if (expression.scalarFn) { std::string enumName = convertFunctionKindToEnumName(expression.scalarFn->kind); // Get the function or attribute name. assert(expression.scalarFn->fnName || expression.scalarFn->attrName); std::string funcType; if (expression.scalarFn->fnName) { funcType = llvm::formatv("{0}::{1}", enumName, *expression.scalarFn->fnName); } if (expression.scalarFn->attrName) { if (llvm::none_of(args, [&](LinalgOperandDef &arg) { return isFunctionAttribute(arg.kind) && arg.name == *expression.scalarFn->attrName; })) { emitError(genContext.getLoc()) << "missing function attribute " << *expression.scalarFn->attrName; } funcType = llvm::formatv("{0}Val", *expression.scalarFn->attrName); } assert(!funcType.empty()); // Add the optional type parameter to the operands. 
SmallVector operandCppValues; if (expression.scalarFn->kind == ScalarFnKind::Type) { assert(expression.scalarFn->typeVar.has_value()); std::optional typeCppValue = findTypeValue(*expression.scalarFn->typeVar, args); if (!typeCppValue) { emitError(genContext.getLoc()) << "type variable " << *expression.scalarFn->typeVar << ", used in a type conversion, must map to a predefined or " << "an argument type but it does not"; return std::nullopt; } operandCppValues.push_back(*typeCppValue); } // Collect the scalar operands. for (ScalarExpression &operand : expression.scalarFn->operands) { auto operandCppValue = generateExpression(operand); if (!operandCppValue) return std::nullopt; operandCppValues.push_back(*operandCppValue); } // Call the function builder. std::string cppIdent = llvm::formatv("value{0}", ++localCounter); stmts.push_back(llvm::formatv( "Value {0} = helper.build{1}({2}, {3});", cppIdent, enumName, funcType, interleaveToString(operandCppValues, ", "))); return cppIdent; } emitError(genContext.getLoc()) << "unknown ScalarExpression type"; return std::nullopt; }; std::optional cppValue = generateExpression(assignment->value); if (!cppValue) return failure(); stmts.push_back(llvm::formatv("yields.push_back({0});", *cppValue)); } if (generatedAssignmentCount != assignments.size()) return emitError(genContext.getLoc()) << "mismatched number of assignments vs output arguments"; os << llvm::formatv(structuredOpRegionBuilderFormat, className, numOfArgs, interleaveToString(attrs, "\n "), interleaveToString(stmts, "\n ")); } // Parser and printer. os << llvm::formatv(structuredOpParserFormat, className); // Canonicalizers and folders. os << llvm::formatv(structuredOpFoldersFormat, className); return success(); } static LogicalResult generateOp(LinalgOpConfig &opConfig, GenerationContext &genContext) { // Switch on op type being generated. 
if (opConfig.structuredOp) { return success( succeeded(generateNamedGenericOpOds(opConfig, genContext)) && succeeded(generateNamedGenericOpDefns(opConfig, genContext))); } return emitError(genContext.getLoc()) << "unsupported operation type"; } //===----------------------------------------------------------------------===// // Command line options and main //===----------------------------------------------------------------------===// static llvm::cl::opt inputFilename(llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("-"), llvm::cl::value_desc("YAML filename")); static llvm::cl::opt outputOdsDeclFilename("o-ods-decl", llvm::cl::desc("ODS output filename"), llvm::cl::value_desc("filename"), llvm::cl::init("")); static llvm::cl::opt outputCppImplFilename("o-impl", llvm::cl::desc("C++ implementation file name"), llvm::cl::value_desc("filename"), llvm::cl::init("")); int main(int argc, char **argv) { llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen from YAML"); // Set up the input file. std::string errorMessage; std::unique_ptr file = mlir::openInputFile(inputFilename, &errorMessage); if (!file) { llvm::errs() << errorMessage << "\n"; return 1; } MLIRContext mlirContext; LinalgYAMLContext yamlContext{&mlirContext}; std::vector opConfigs; // Parse input. Input yin(file->getBuffer(), &yamlContext); yin >> opConfigs; if (yin.error()) return 1; // Open output files. std::unique_ptr outputOdsDecl; if (!outputOdsDeclFilename.empty()) { outputOdsDecl = openOutputFile(outputOdsDeclFilename, &errorMessage); if (!outputOdsDecl) { llvm::errs() << errorMessage << "\n"; return 1; } } std::unique_ptr outputCppImpl; if (!outputCppImplFilename.empty()) { outputCppImpl = openOutputFile(outputCppImplFilename, &errorMessage); if (!outputCppImpl) { llvm::errs() << errorMessage << "\n"; return 1; } } if (!outputOdsDecl && !outputCppImpl) { llvm::errs() << "error: No output files specified\n"; return 1; } // Generate. 
GenerationContext genContext(&mlirContext, outputOdsDecl ? &outputOdsDecl->os() : nullptr, outputCppImpl ? &outputCppImpl->os() : nullptr); for (auto &opConfig : opConfigs) { if (!opConfig.metadata) { emitError(genContext.getLoc()) << "missing operation metadata on subsequent op"; return 1; } genContext.setLoc(NameLoc::get( StringAttr::get(&mlirContext, opConfig.metadata->cppClassName))); if (failed(generateOp(opConfig, genContext))) { return 1; } } if (outputOdsDecl) outputOdsDecl->keep(); if (outputCppImpl) outputCppImpl->keep(); return 0; }