aboutsummaryrefslogtreecommitdiff
path: root/mlir/include/mlir/Pass/PassOptions.h
blob: 6717a3585d12a5586f172b9650426d984a5e28eb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
//===- PassOptions.h - Pass Option Utilities --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains utilities for registering options with compiler passes and
// pipelines.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_PASS_PASSOPTIONS_H_
#define MLIR_PASS_PASSOPTIONS_H_

#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/FunctionExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include <memory>

namespace mlir {
class OpPassManager;

namespace detail {
namespace pass_options {
/// Parse a string containing a list of comma-delimited elements, calling
/// `elementParseFn` with the text of each sub-element. Defined out of line.
LogicalResult
parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName,
                        StringRef optionStr,
                        function_ref<LogicalResult(StringRef)> elementParseFn);
/// Typed convenience overload of the function above. Every sub-element string
/// is parsed with `elementParser` into a freshly value-initialized
/// `ElementParser::parser_data_type`, which is then handed to `appendFn`.
/// When `elementParser.parse` returns true (the llvm::cl error convention) the
/// per-element callback reports failure and `appendFn` is not invoked for that
/// element.
template <typename ElementParser, typename ElementAppendFn>
LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName,
                                      StringRef optionStr,
                                      ElementParser &elementParser,
                                      ElementAppendFn &&appendFn) {
  auto parseOneElement = [&](StringRef elementStr) -> LogicalResult {
    typename ElementParser::parser_data_type parsedValue = {};
    const bool parseError =
        elementParser.parse(opt, argName, elementStr, parsedValue);
    if (parseError)
      return failure();
    appendFn(parsedValue);
    return success();
  };
  return parseCommaSeparatedList(opt, argName, optionStr, parseOneElement);
}

/// Detection-idiom probe: the type of `raw_ostream << ValueT`. It is only
/// well-formed when such an operator<< exists.
template <typename ValueT>
using has_stream_operator_trait =
    decltype(std::declval<raw_ostream &>() << std::declval<ValueT>());
/// Trait used to detect whether a type has an operator<< for raw_ostream;
/// `has_stream_operator<ValueT>::value` is true exactly when the probe above
/// is well-formed.
template <typename ValueT>
using has_stream_operator =
    llvm::is_detected<has_stream_operator_trait, ValueT>;

/// Utility overloads for printing option values:
///  * `bool` values have a dedicated overload printing `true` / `false`;
///  * types with a raw_ostream operator<< are streamed directly;
///  * anything else is forwarded to the static `ParserT::print`.
template <typename ParserT>
static void printOptionValue(raw_ostream &os, const bool &value) {
  StringRef spelling = value ? StringRef("true") : StringRef("false");
  os << spelling;
}
template <typename ParserT, typename DataT>
static std::enable_if_t<has_stream_operator<DataT>::value>
printOptionValue(raw_ostream &os, const DataT &value) {
  // The value type knows how to stream itself.
  os << value;
}
template <typename ParserT, typename DataT>
static std::enable_if_t<!has_stream_operator<DataT>::value>
printOptionValue(raw_ostream &os, const DataT &value) {
  // No stream operator is available for DataT, so defer to the printer
  // provided by the parser type.
  ParserT::print(os, value);
}
} // namespace pass_options

/// Base container class and manager for all pass options.
///
/// Options are declared as members of a subclass. Each `Option` / `ListOption`
/// constructed with a `PassOptions &parent` records itself in `options` and is
/// attached to the parent, which serves as its `llvm::cl::SubCommand`.
class PassOptions : protected llvm::cl::SubCommand {
private:
  /// This is the type-erased option base class. This provides some additional
  /// hooks into the options that are not available via llvm::cl::Option, so
  /// that PassOptions can print and copy options without knowing their
  /// concrete type.
  class OptionBase {
  public:
    virtual ~OptionBase() = default;

    /// Out of line virtual function to provide home for the class.
    virtual void anchor();

    /// Print the name and value of this option to the given stream.
    virtual void print(raw_ostream &os) = 0;

    /// Return the argument string of this option.
    StringRef getArgStr() const { return getOption()->ArgStr; }

    /// Returns true if this option has any value assigned to it.
    bool hasValue() const { return optHasValue; }

  protected:
    /// Return the main option instance.
    virtual const llvm::cl::Option *getOption() const = 0;

    /// Copy the value from the given option into this one. The concrete
    /// implementations static_cast `other` to their own type, so `other` must
    /// be an option of the same concrete type as this one.
    virtual void copyValueFrom(const OptionBase &other) = 0;

    /// Flag indicating if this option has a value.
    bool optHasValue = false;

    /// Allow access to private methods.
    friend PassOptions;
  };

  /// This is the parser that is used by pass options that use literal options.
  /// This is a thin wrapper around the llvm::cl::parser, that exposes some
  /// additional methods.
  template <typename DataType>
  struct GenericOptionParser : public llvm::cl::parser<DataType> {
    using llvm::cl::parser<DataType>::parser;

    /// Returns an argument name that maps to the specified value: the `Name`
    /// of the first registered entry whose `V.compare(value)` is true, or
    /// std::nullopt if there is no such entry.
    std::optional<StringRef> findArgStrForValue(const DataType &value) {
      for (auto &it : this->Values)
        if (it.V.compare(value))
          return it.Name;
      return std::nullopt;
    }
  };

  /// Print `value` as the argument name registered for it in `parser`. A value
  /// with no registered name is treated as unreachable.
  template <typename DataT>
  static void printValue(raw_ostream &os, GenericOptionParser<DataT> &parser,
                         const DataT &value) {
    if (std::optional<StringRef> argStr = parser.findArgStrForValue(value))
      os << *argStr;
    else
      llvm_unreachable("unknown data value for option");
  }
  /// Print `value` for every other parser kind by deferring to
  /// detail::pass_options::printOptionValue.
  template <typename DataT, typename ParserT>
  static void printValue(raw_ostream &os, ParserT &parser, const DataT &value) {
    detail::pass_options::printOptionValue<ParserT>(os, value);
  }

public:
  /// The specific parser to use depending on llvm::cl parser used. This is only
  /// necessary because we need to provide additional methods for certain data
  /// type parsers.
  /// TODO: We should upstream the methods in GenericOptionParser to avoid the
  /// need to do this.
  template <typename DataType>
  using OptionParser =
      std::conditional_t<std::is_base_of<llvm::cl::generic_parser_base,
                                         llvm::cl::parser<DataType>>::value,
                         GenericOptionParser<DataType>,
                         llvm::cl::parser<DataType>>;

  /// This class represents a specific pass option, with a provided data type.
  template <typename DataType, typename OptionParser = OptionParser<DataType>>
  class Option
      : public llvm::cl::opt<DataType, /*ExternalStorage=*/false, OptionParser>,
        public OptionBase {
  public:
    /// Create an option named `arg` owned by `parent`; any extra llvm::cl
    /// modifiers in `args` are forwarded to llvm::cl::opt.
    template <typename... Args>
    Option(PassOptions &parent, StringRef arg, Args &&...args)
        : llvm::cl::opt<DataType, /*ExternalStorage=*/false, OptionParser>(
              arg, llvm::cl::sub(parent), std::forward<Args>(args)...) {
      assert(!this->isPositional() && !this->isSink() &&
             "sink and positional options are not supported");
      parent.options.push_back(this);

      // Set a callback to track if this option has a value.
      this->setCallback([this](const auto &) { this->optHasValue = true; });
    }
    ~Option() override = default;
    using llvm::cl::opt<DataType, /*ExternalStorage=*/false,
                        OptionParser>::operator=;
    /// Assign the value held by `other` to this option.
    /// NOTE(review): unlike ListOption's copy assignment, `other.optHasValue`
    /// is not copied explicitly here; whether `optHasValue` ends up set
    /// depends on whether the inherited operator= fires the callback installed
    /// in the constructor. Confirm this is intended.
    Option &operator=(const Option &other) {
      *this = other.getValue();
      return *this;
    }

  private:
    /// Return the main option instance.
    const llvm::cl::Option *getOption() const final { return this; }

    /// Print the name and value of this option to the given stream, as
    /// `<arg>=<value>`.
    void print(raw_ostream &os) final {
      os << this->ArgStr << '=';
      printValue(os, this->getParser(), this->getValue());
    }

    /// Copy the value (and the has-value flag) from the given option into this
    /// one. `other` must be an Option with the same template arguments.
    void copyValueFrom(const OptionBase &other) final {
      this->setValue(static_cast<const Option<DataType, OptionParser> &>(other)
                         .getValue());
      optHasValue = other.optHasValue;
    }
  };

  /// This class represents a specific pass option that contains a list of
  /// values of the provided data type. The elements within the textual form of
  /// this option are parsed assuming they are comma-separated. Delimited
  /// sub-ranges within individual elements of the list may contain commas that
  /// are not treated as separators for the top-level list.
  template <typename DataType, typename OptionParser = OptionParser<DataType>>
  class ListOption
      : public llvm::cl::list<DataType, /*StorageClass=*/bool, OptionParser>,
        public OptionBase {
  public:
    /// Create a list option named `arg` owned by `parent`; any extra llvm::cl
    /// modifiers in `args` are forwarded to llvm::cl::list.
    template <typename... Args>
    ListOption(PassOptions &parent, StringRef arg, Args &&...args)
        : llvm::cl::list<DataType, /*StorageClass=*/bool, OptionParser>(
              arg, llvm::cl::sub(parent), std::forward<Args>(args)...),
          elementParser(*this) {
      assert(!this->isPositional() && !this->isSink() &&
             "sink and positional options are not supported");
      assert(!(this->getMiscFlags() & llvm::cl::MiscFlags::CommaSeparated) &&
             "ListOption is implicitly comma separated, specifying "
             "CommaSeparated is extraneous");
      parent.options.push_back(this);
      elementParser.initialize();
    }
    ~ListOption() override = default;
    /// Copy the elements and the has-value flag from `other`.
    ListOption<DataType, OptionParser> &
    operator=(const ListOption<DataType, OptionParser> &other) {
      *this = ArrayRef<DataType>(other);
      this->optHasValue = other.optHasValue;
      return *this;
    }

    /// Handle one textual occurrence of this option. When
    /// `isDefaultAssigned()` is true, the current contents are cleared and
    /// `overwriteDefault()` is called first; the elements parsed from `arg`
    /// are then appended. Returns true on a parse error (llvm::cl convention).
    bool handleOccurrence(unsigned pos, StringRef argName,
                          StringRef arg) override {
      if (this->isDefaultAssigned()) {
        this->clear();
        this->overwriteDefault();
      }
      this->optHasValue = true;
      return failed(detail::pass_options::parseCommaSeparatedList(
          *this, argName, arg, elementParser,
          [&](const DataType &value) { this->addValue(value); }));
    }

    /// Allow assigning from an ArrayRef. Replaces the current contents with
    /// `values` and marks the option as having a value.
    ListOption<DataType, OptionParser> &operator=(ArrayRef<DataType> values) {
      // Use a checked static_cast (as operator* below does) instead of a
      // C-style cast, which would silently degrade to a reinterpret_cast if
      // the conversion to std::vector<DataType>& ever became unavailable.
      static_cast<std::vector<DataType> &>(*this).assign(values.begin(),
                                                         values.end());
      optHasValue = true;
      return *this;
    }

    /// Allow accessing the data held by this option.
    MutableArrayRef<DataType> operator*() {
      return static_cast<std::vector<DataType> &>(*this);
    }
    ArrayRef<DataType> operator*() const {
      return static_cast<const std::vector<DataType> &>(*this);
    }

  private:
    /// Return the main option instance.
    const llvm::cl::Option *getOption() const final { return this; }

    /// Print the name and value of this option to the given stream, as
    /// `<arg>=<elt0>,<elt1>,...`.
    void print(raw_ostream &os) final {
      // Don't print the list if empty. An empty option value can be treated as
      // an element of the list in certain cases (e.g. ListOption<std::string>).
      if ((**this).empty())
        return;

      os << this->ArgStr << '=';
      auto printElementFn = [&](const DataType &value) {
        printValue(os, this->getParser(), value);
      };
      llvm::interleave(*this, os, printElementFn, ",");
    }

    /// Copy the value from the given option into this one. `other` must be a
    /// ListOption with the same template arguments.
    void copyValueFrom(const OptionBase &other) final {
      *this = static_cast<const ListOption<DataType, OptionParser> &>(other);
    }

    /// The parser to use for parsing the list elements.
    OptionParser elementParser;
  };

  PassOptions() = default;
  /// Copying and moving are disabled: `options` holds raw pointers to this
  /// object's own option members, which a copy or move would not retarget.
  PassOptions(const PassOptions &) = delete;
  PassOptions(PassOptions &&) = delete;

  /// Copy the option values from 'other' into 'this', where 'other' has the
  /// same options as 'this'.
  void copyOptionValuesFrom(const PassOptions &other);

  /// Parse options out as key=value pairs that can then be handed off to the
  /// `llvm::cl` command line parsing infrastructure. Everything is space
  /// separated.
  LogicalResult parseFromString(StringRef options);

  /// Print the options held by this struct in a form that can be parsed via
  /// 'parseFromString'.
  void print(raw_ostream &os);

  /// Print the help string for the options held by this struct. `descIndent` is
  /// the indent that the descriptions should be aligned.
  void printHelp(size_t indent, size_t descIndent) const;

  /// Return the maximum width required when printing the help string.
  size_t getOptionWidth() const;

private:
  /// A list of all of the opaque options, in registration order. Entries are
  /// non-owning pointers to option members of this object.
  std::vector<OptionBase *> options;
};
} // namespace detail

//===----------------------------------------------------------------------===//
// PassPipelineOptions
//===----------------------------------------------------------------------===//

/// Subclasses of PassPipelineOptions provide a set of options that can be used
/// to initialize a pass pipeline. See PassPipelineRegistration for usage
/// details.
///
/// Usage:
///
/// struct MyPipelineOptions : PassPipelineOptions<MyPipelineOptions> {
///   ListOption<int> someListFlag{*this, "flag-name", llvm::cl::desc("...")};
/// };
template <typename T>
class PassPipelineOptions : public detail::PassOptions {
public:
  /// Factory: default-construct a `T`, populate it from the textual
  /// `options` via parseFromString, and hand ownership to the caller.
  /// Returns nullptr when parsing fails.
  static std::unique_ptr<T> createFromString(StringRef options) {
    std::unique_ptr<T> parsed = std::make_unique<T>();
    if (succeeded(parsed->parseFromString(options)))
      return parsed;
    return nullptr;
  }
};

/// A default empty option struct to be used for passes that do not need to take
/// any options.
struct EmptyPipelineOptions
    : public PassPipelineOptions<EmptyPipelineOptions> {
  // Intentionally empty: this options struct declares no options of its own.
};
} // namespace mlir

//===----------------------------------------------------------------------===//
// MLIR Options
//===----------------------------------------------------------------------===//

namespace llvm {
namespace cl {
//===----------------------------------------------------------------------===//
// std::vector+SmallVector

namespace detail {
template <typename VectorT, typename ElementT>
class VectorParserBase : public basic_parser_impl {
public:
  VectorParserBase(Option &opt) : basic_parser_impl(opt), elementParser(opt) {}

  /// The value type produced by `parse`.
  using parser_data_type = VectorT;

  /// Parse `arg`, which must be of the form `[elt0,elt1,...]`, appending each
  /// element (parsed with `elementParser`) to `vector`. Follows the llvm::cl
  /// parser convention of returning true on error.
  bool parse(Option &opt, StringRef argName, StringRef arg,
             parser_data_type &vector) {
    if (!arg.consume_front("[") || !arg.consume_back("]")) {
      return opt.error("expected vector option to be wrapped with '[]'",
                       argName);
    }

    return failed(mlir::detail::pass_options::parseCommaSeparatedList(
        opt, argName, arg, elementParser,
        [&](const ElementT &value) { vector.push_back(value); }));
  }

  /// Print `vector` as `[elt0,elt1,...]`. This uses the same bracketed form
  /// that `parse` requires, so a printed option can be parsed back; without
  /// the brackets the printed value would be rejected by `parse`. Elements
  /// are printed with printOptionValue, so nested vectors are bracketed too.
  static void print(raw_ostream &os, const VectorT &vector) {
    os << '[';
    llvm::interleave(
        vector, os,
        [&](const ElementT &value) {
          mlir::detail::pass_options::printOptionValue<
              llvm::cl::parser<ElementT>>(os, value);
        },
        ",");
    os << ']';
  }

  /// Print the help entry for this option as `--<arg>=<vector<elt>>` followed
  /// by the help string.
  void printOptionInfo(const Option &opt, size_t globalWidth) const {
    // Add the `vector<>` qualifier to the option info.
    outs() << "  --" << opt.ArgStr;
    outs() << "=<vector<" << elementParser.getValueName() << ">>";
    Option::printHelpStr(opt.HelpStr, globalWidth, getOptionWidth(opt));
  }

  /// Width of the element parser's entry plus the `vector<>` qualifier.
  size_t getOptionWidth(const Option &opt) const {
    // Add the `vector<>` qualifier to the option width.
    StringRef vectorExt("vector<>");
    return elementParser.getOptionWidth(opt) + vectorExt.size();
  }

private:
  /// Parser used for the individual elements of the vector.
  llvm::cl::parser<ElementT> elementParser;
};
} // namespace detail

/// llvm::cl parser for `std::vector<T>` options. All behavior is inherited
/// from detail::VectorParserBase.
template <typename T>
class parser<std::vector<T>>
    : public detail::VectorParserBase<std::vector<T>, T> {
  using Base = detail::VectorParserBase<std::vector<T>, T>;

public:
  parser(Option &opt) : Base(opt) {}
};
/// llvm::cl parser for `SmallVector<T, N>` options. All behavior is inherited
/// from detail::VectorParserBase.
template <typename T, unsigned N>
class parser<SmallVector<T, N>>
    : public detail::VectorParserBase<SmallVector<T, N>, T> {
  using Base = detail::VectorParserBase<SmallVector<T, N>, T>;

public:
  parser(Option &opt) : Base(opt) {}
};

//===----------------------------------------------------------------------===//
// OpPassManager: OptionValue

template <>
struct OptionValue<mlir::OpPassManager> final : GenericOptionValue {
  using WrapperType = mlir::OpPassManager;

  OptionValue();
  OptionValue(const OptionValue<mlir::OpPassManager> &rhs);
  OptionValue(const mlir::OpPassManager &value);
  OptionValue<mlir::OpPassManager> &operator=(const mlir::OpPassManager &rhs);
  ~OptionValue();

  /// Returns if the current option has a value.
  bool hasValue() const { return value.get(); }

  /// Returns the current value of the option.
  mlir::OpPassManager &getValue() const {
    assert(hasValue() && "invalid option value");
    return *value;
  }

  /// Set the value of the option.
  void setValue(const mlir::OpPassManager &newValue);
  void setValue(StringRef pipelineStr);

  /// Compare the option with the provided value.
  bool compare(const mlir::OpPassManager &rhs) const;
  bool compare(const GenericOptionValue &rhs) const override {
    const auto &rhsOV =
        static_cast<const OptionValue<mlir::OpPassManager> &>(rhs);
    if (!rhsOV.hasValue())
      return false;
    return compare(rhsOV.getValue());
  }

private:
  void anchor() override;

  /// The underlying pass manager. We use a unique_ptr to avoid the need for the
  /// full type definition.
  std::unique_ptr<mlir::OpPassManager> value;
};

//===----------------------------------------------------------------------===//
// OpPassManager: Parser

extern template class basic_parser<mlir::OpPassManager>;

template <>
class parser<mlir::OpPassManager> : public basic_parser<mlir::OpPassManager> {
public:
  /// Holder for a parsed pass manager. It owns the result through a
  /// unique_ptr, so OpPassManager itself does not need a default constructor.
  struct ParsedPassManager {
    ParsedPassManager();
    ParsedPassManager(ParsedPassManager &&);
    ~ParsedPassManager();

    /// View the parsed result as a pass manager; asserts that one is present.
    operator const mlir::OpPassManager &() const {
      assert(value != nullptr && "parsed value was invalid");
      return *value;
    }

    /// The parsed pass manager.
    std::unique_ptr<mlir::OpPassManager> value;
  };
  using parser_data_type = ParsedPassManager;
  using OptVal = OptionValue<mlir::OpPassManager>;

  parser(Option &opt) : basic_parser(opt) {}

  /// Parse `arg` into `value`; returns true on error (llvm::cl convention).
  bool parse(Option &opt, StringRef argName, StringRef arg,
             ParsedPassManager &value);

  /// Print an instance of the underlying option value to the given stream.
  static void print(raw_ostream &os, const mlir::OpPassManager &value);

  /// Name shown for this option's value in help output.
  StringRef getValueName() const override { return "pass-manager"; }

  /// Print the current value of `opt` alongside its default value.
  void printOptionDiff(const Option &opt, mlir::OpPassManager &pm,
                       const OptVal &defaultValue, size_t globalWidth) const;

  /// Out-of-line virtual method providing a 'home' for this class.
  void anchor() override;
};

} // namespace cl
} // namespace llvm

#endif // MLIR_PASS_PASSOPTIONS_H_