#!/usr/bin/python
#
# top-like utility for displaying kvm statistics
#
# Copyright 2006-2008 Qumranet Technologies
# Copyright 2008-2011 Red Hat, Inc.
#
# Authors:
#  Avi Kivity <avi@redhat.com>
#
# This work is licensed under the terms of the GNU GPL, version 2.  See
# the COPYING file in the top-level directory.

import curses
import sys
import os
import time
import optparse
import ctypes
import fcntl
import resource
import struct
import re
from collections import defaultdict

class DebugfsProvider(object):
    """Statistics provider that reads counters from PATH_DEBUGFS_KVM.

    Every regular file found directly in PATH_DEBUGFS_KVM at construction
    time is treated as one field whose content is a single integer.
    """
    def __init__(self):
        # walkdir() returns (dirpath, dirnames, filenames); the file names
        # are the available fields.
        self._fields = walkdir(PATH_DEBUGFS_KVM)[2]
    def fields(self):
        """Return the list of currently selected field names."""
        return self._fields
    def select(self, fields):
        """Restrict subsequent read() calls to the given field names."""
        self._fields = fields
    def read(self):
        """Return a dict mapping each selected field to its integer value."""
        def val(key):
            # Close the file right away: read() is called on every refresh
            # and must not depend on garbage collection to release handles.
            with open(PATH_DEBUGFS_KVM + '/' + key) as stat_file:
                return int(stat_file.read())
        return dict([(key, val(key)) for key in self._fields])

# Exit-reason name -> numeric code table registered under the 'vmx' key of
# X86_EXIT_REASONS.  When selected as EXIT_REASONS, each entry becomes a
# 'kvm_exit(<name>)' sub-event filtered on 'exit_reason==<code>'.
VMX_EXIT_REASONS = {
    'EXCEPTION_NMI':        0,
    'EXTERNAL_INTERRUPT':   1,
    'TRIPLE_FAULT':         2,
    'PENDING_INTERRUPT':    7,
    'NMI_WINDOW':           8,
    'TASK_SWITCH':          9,
    'CPUID':                10,
    'HLT':                  12,
    'INVLPG':               14,
    'RDPMC':                15,
    'RDTSC':                16,
    'VMCALL':               18,
    'VMCLEAR':              19,
    'VMLAUNCH':             20,
    'VMPTRLD':              21,
    'VMPTRST':              22,
    'VMREAD':               23,
    'VMRESUME':             24,
    'VMWRITE':              25,
    'VMOFF':                26,
    'VMON':                 27,
    'CR_ACCESS':            28,
    'DR_ACCESS':            29,
    'IO_INSTRUCTION':       30,
    'MSR_READ':             31,
    'MSR_WRITE':            32,
    'INVALID_STATE':        33,
    'MWAIT_INSTRUCTION':    36,
    'MONITOR_INSTRUCTION':  39,
    'PAUSE_INSTRUCTION':    40,
    'MCE_DURING_VMENTRY':   41,
    'TPR_BELOW_THRESHOLD':  43,
    'APIC_ACCESS':          44,
    'EPT_VIOLATION':        48,
    'EPT_MISCONFIG':        49,
    'WBINVD':               54,
    'XSETBV':               55,
    'APIC_WRITE':           56,
    'INVPCID':              58,
}

# Exit-reason name -> numeric code table registered under the 'svm' key of
# X86_EXIT_REASONS.  When selected as EXIT_REASONS, each entry becomes a
# 'kvm_exit(<name>)' sub-event filtered on 'exit_reason==<code>'.
SVM_EXIT_REASONS = {
    'READ_CR0':       0x000,
    'READ_CR3':       0x003,
    'READ_CR4':       0x004,
    'READ_CR8':       0x008,
    'WRITE_CR0':      0x010,
    'WRITE_CR3':      0x013,
    'WRITE_CR4':      0x014,
    'WRITE_CR8':      0x018,
    'READ_DR0':       0x020,
    'READ_DR1':       0x021,
    'READ_DR2':       0x022,
    'READ_DR3':       0x023,
    'READ_DR4':       0x024,
    'READ_DR5':       0x025,
    'READ_DR6':       0x026,
    'READ_DR7':       0x027,
    'WRITE_DR0':      0x030,
    'WRITE_DR1':      0x031,
    'WRITE_DR2':      0x032,
    'WRITE_DR3':      0x033,
    'WRITE_DR4':      0x034,
    'WRITE_DR5':      0x035,
    'WRITE_DR6':      0x036,
    'WRITE_DR7':      0x037,
    'EXCP_BASE':      0x040,
    'INTR':           0x060,
    'NMI':            0x061,
    'SMI':            0x062,
    'INIT':           0x063,
    'VINTR':          0x064,
    'CR0_SEL_WRITE':  0x065,
    'IDTR_READ':      0x066,
    'GDTR_READ':      0x067,
    'LDTR_READ':      0x068,
    'TR_READ':        0x069,
    'IDTR_WRITE':     0x06a,
    'GDTR_WRITE':     0x06b,
    'LDTR_WRITE':     0x06c,
    'TR_WRITE':       0x06d,
    'RDTSC':          0x06e,
    'RDPMC':          0x06f,
    'PUSHF':          0x070,
    'POPF':           0x071,
    'CPUID':          0x072,
    'RSM':            0x073,
    'IRET':           0x074,
    'SWINT':          0x075,
    'INVD':           0x076,
    'PAUSE':          0x077,
    'HLT':            0x078,
    'INVLPG':         0x079,
    'INVLPGA':        0x07a,
    'IOIO':           0x07b,
    'MSR':            0x07c,
    'TASK_SWITCH':    0x07d,
    'FERR_FREEZE':    0x07e,
    'SHUTDOWN':       0x07f,
    'VMRUN':          0x080,
    'VMMCALL':        0x081,
    'VMLOAD':         0x082,
    'VMSAVE':         0x083,
    'STGI':           0x084,
    'CLGI':           0x085,
    'SKINIT':         0x086,
    'RDTSCP':         0x087,
    'ICEBP':          0x088,
    'WBINVD':         0x089,
    'MONITOR':        0x08a,
    'MWAIT':          0x08b,
    'MWAIT_COND':     0x08c,
    'XSETBV':         0x08d,
    'NPF':            0x400,
}

# EC definition of HSR (from arch/arm64/include/asm/kvm_arm.h)
# Name -> numeric code table; aarch64_init() installs it as EXIT_REASONS, so
# each entry becomes a 'kvm_exit(<name>)' sub-event filtered on
# 'exit_reason==<code>'.
AARCH64_EXIT_REASONS = {
    'UNKNOWN':      0x00,
    'WFI':          0x01,
    'CP15_32':      0x03,
    'CP15_64':      0x04,
    'CP14_MR':      0x05,
    'CP14_LS':      0x06,
    'FP_ASIMD':     0x07,
    'CP10_ID':      0x08,
    'CP14_64':      0x0C,
    'ILL_ISS':      0x0E,
    'SVC32':        0x11,
    'HVC32':        0x12,
    'SMC32':        0x13,
    'SVC64':        0x15,
    'HVC64':        0x16,
    'SMC64':        0x17,
    'SYS64':        0x18,
    'IABT':         0x20,
    'IABT_HYP':     0x21,
    'PC_ALIGN':     0x22,
    'DABT':         0x24,
    'DABT_HYP':     0x25,
    'SP_ALIGN':     0x26,
    'FP_EXC32':     0x28,
    'FP_EXC64':     0x2C,
    'SERROR':       0x2F,
    'BREAKPT':      0x30,
    'BREAKPT_HYP':  0x31,
    'SOFTSTP':      0x32,
    'SOFTSTP_HYP':  0x33,
    'WATCHPT':      0x34,
    'WATCHPT_HYP':  0x35,
    'BKPT32':       0x38,
    'VECTOR32':     0x3A,
    'BRK64':        0x3C,
}

# From include/uapi/linux/kvm.h, KVM_EXIT_xxx
# Name -> numeric code table; always registered as the filter table for the
# 'kvm_userspace_exit' tracepoint (matched against its 'reason' field).
USERSPACE_EXIT_REASONS = {
    'UNKNOWN':          0,
    'EXCEPTION':        1,
    'IO':               2,
    'HYPERCALL':        3,
    'DEBUG':            4,
    'HLT':              5,
    'MMIO':             6,
    'IRQ_WINDOW_OPEN':  7,
    'SHUTDOWN':         8,
    'FAIL_ENTRY':       9,
    'INTR':             10,
    'SET_TPR':          11,
    'TPR_ACCESS':       12,
    'S390_SIEIC':       13,
    'S390_RESET':       14,
    'DCR':              15,
    'NMI':              16,
    'INTERNAL_ERROR':   17,
    'OSI':              18,
    'PAPR_HCALL':       19,
    'S390_UCONTROL':    20,
    'WATCHDOG':         21,
    'S390_TSCH':        22,
    'EPR':              23,
    'SYSTEM_EVENT':     24,
}

# Keyed by the /proc/cpuinfo CPU flag that detect_platform() looks for; the
# value is the exit-reason table to use on that kind of x86 host.
X86_EXIT_REASONS = dict(vmx = VMX_EXIT_REASONS, svm = SVM_EXIT_REASONS)

# Platform-dependent settings; both stay None until one of the *_init()
# functions below fills them in.
# Syscall number handed to libc's syscall() by _perf_event_open().
SC_PERF_EVT_OPEN = None
# Exit-reason table used to build the 'kvm_exit' filter (if any).
EXIT_REASONS = None

# ioctl request codes issued on perf event file descriptors by Event.
# ppc_init() overrides ENABLE, DISABLE and SET_FILTER.
IOCTL_NUMBERS = {
    'SET_FILTER' : 0x40082406,
    'ENABLE'     : 0x00002400,
    'DISABLE'    : 0x00002401,
    'RESET'      : 0x00002403,
}

def x86_init(flag):
    """Set the platform globals for an x86 host.

    flag is a key of X86_EXIT_REASONS ('vmx' or 'svm'); the matching table
    becomes EXIT_REASONS.  Raises KeyError for any other value.
    """
    global SC_PERF_EVT_OPEN, EXIT_REASONS

    SC_PERF_EVT_OPEN = 298
    reasons = X86_EXIT_REASONS[flag]
    EXIT_REASONS = reasons

def s390_init():
    """Set the platform globals for an s390 host.

    Only the perf_event_open syscall number is set; EXIT_REASONS is left
    untouched.
    """
    global SC_PERF_EVT_OPEN
    SC_PERF_EVT_OPEN = 331

def ppc_init():
    """Set the platform globals for a ppc host.

    Besides the syscall number, the ENABLE, DISABLE and SET_FILTER entries
    of IOCTL_NUMBERS are replaced in place; SET_FILTER encodes the size of a
    C char pointer in bits 16 and up.
    """
    global SC_PERF_EVT_OPEN

    SC_PERF_EVT_OPEN = 319

    pointer_size = ctypes.sizeof(ctypes.c_char_p)
    IOCTL_NUMBERS.update({
        'ENABLE': 0x20002400,
        'DISABLE': 0x20002401,
        'SET_FILTER': 0x80002406 | (pointer_size << 16),
    })

def aarch64_init():
    """Set the platform globals for an aarch64 host.

    Installs the syscall number and makes AARCH64_EXIT_REASONS the active
    EXIT_REASONS table.
    """
    global SC_PERF_EVT_OPEN, EXIT_REASONS

    SC_PERF_EVT_OPEN = 241
    EXIT_REASONS = AARCH64_EXIT_REASONS

def detect_platform():
    """Detect the host architecture and run the matching *_init() function.

    The machine name from os.uname() decides ppc and aarch64.  Otherwise
    /proc/cpuinfo is scanned: a 'flags' line containing a key of
    X86_EXIT_REASONS selects x86, a 'vendor_id' line containing 'IBM/S390'
    selects s390.  If nothing matches, no init function is called and the
    platform globals keep their initial values.
    """
    machine = os.uname()[4]
    if machine.startswith('ppc'):
        ppc_init()
        return
    if machine.startswith('aarch64'):
        aarch64_init()
        return

    # The context manager closes the file on every exit path, including the
    # early returns below.
    with open('/proc/cpuinfo') as cpuinfo:
        for line in cpuinfo:
            if line.startswith('flags'):
                for flag in line.split():
                    if flag in X86_EXIT_REASONS:
                        x86_init(flag)
                        return
            elif line.startswith('vendor_id'):
                if 'IBM/S390' in line.split():
                    s390_init()
                    return

detect_platform()


def walkdir(path):
    """Return the top-level os.walk() entry for path.

    The result is the usual (dirpath, dirnames, filenames) 3-tuple
    describing path itself; subdirectories are not descended into.
    """
    walker = os.walk(path)
    return next(walker)

# Tracepoint name -> (field to match, {sub-event name: numeric value}).
# TracepointProvider expands every entry into '<tracepoint>(<name>)' events.
filters = {'kvm_userspace_exit': ('reason', USERSPACE_EXIT_REASONS)}
if EXIT_REASONS:
    # Only available when platform detection picked an exit-reason table.
    filters['kvm_exit'] = ('exit_reason', EXIT_REASONS)

# Handles into the C library: the raw syscall() entry point used by
# _perf_event_open(), and __errno_location(), which Event dereferences to
# report the errno of a failed perf_event_open call.
libc = ctypes.CDLL('libc.so.6')
syscall = libc.syscall
get_errno = libc.__errno_location
# Declare the return type as int* so the result can be indexed with [0].
get_errno.restype = ctypes.POINTER(ctypes.c_int)

class perf_event_attr(ctypes.Structure):
    """ctypes layout of the attribute block given to perf_event_open.

    Event fills in type, size, config and read_format before passing a
    pointer to an instance to _perf_event_open().
    NOTE(review): presumably mirrors the leading members of the kernel's
    struct perf_event_attr -- confirm against the kernel headers.
    """
    _fields_ = [
        ('type', ctypes.c_uint32),
        ('size', ctypes.c_uint32),
        ('config', ctypes.c_uint64),
        ('sample_freq', ctypes.c_uint64),
        ('sample_type', ctypes.c_uint64),
        ('read_format', ctypes.c_uint64),
        ('flags', ctypes.c_uint64),
        ('wakeup_events', ctypes.c_uint32),
        ('bp_type', ctypes.c_uint32),
        ('bp_addr', ctypes.c_uint64),
        ('bp_len', ctypes.c_uint64),
    ]
def _perf_event_open(attr, pid, cpu, group_fd, flags):
    """Issue the perf_event_open syscall through libc's syscall().

    attr is passed by pointer; pid, cpu and group_fd are converted to C
    ints and flags to a C long.  Returns whatever syscall() returns (Event
    treats -1 as failure).
    """
    c_args = (ctypes.pointer(attr),
              ctypes.c_int(pid),
              ctypes.c_int(cpu),
              ctypes.c_int(group_fd),
              ctypes.c_long(flags))
    return syscall(SC_PERF_EVT_OPEN, *c_args)

# Value Event stores in perf_event_attr.type.
PERF_TYPE_TRACEPOINT = 2
# Value Event stores in perf_event_attr.read_format; Group.read() then reads
# all counters of a group through the first event's file descriptor.
PERF_FORMAT_GROUP = 1 << 3

# Locations of the tracing event directory tree and of the kvm debugfs
# statistics files.
PATH_DEBUGFS_TRACING = '/sys/kernel/debug/tracing'
PATH_DEBUGFS_KVM = '/sys/kernel/debug/kvm'

class Group(object):
    """A set of perf events opened on one CPU and read together.

    The first event added becomes the one whose file descriptor is wrapped
    in self.file; read() fetches every counter of the group through it.
    """
    def __init__(self, cpu):
        self.cpu = cpu
        self.group_leader = None
        self.events = []
    def add_event(self, name, event_set, tracepoint, filter = None):
        """Create an Event in this group and append it to self.events."""
        new_event = Event(group = self, name = name, event_set = event_set,
                          tracepoint = tracepoint, filter = filter)
        self.events.append(new_event)
        if len(self.events) == 1:
            # First event: keep a file object for its descriptor.
            self.file = os.fdopen(self.events[0].fd)
    def read(self):
        """Return a dict mapping event names to their counter values.

        One 8-byte header is skipped, followed by one signed 64-bit value
        per event in the order the events were added.
        """
        count = len(self.events)
        names = [event.name for event in self.events]
        raw = self.file.read(8 * (1 + count))
        counters = struct.unpack('xxxxxxxx' + 'q' * count, raw)
        return dict(zip(names, counters))

class Event(object):
    """One perf tracepoint counter belonging to a Group.

    The constructor looks up the tracepoint id under PATH_DEBUGFS_TRACING,
    opens a perf event for it on the group's CPU (pid -1) and optionally
    installs a filter string.  The resulting descriptor is kept in self.fd.

    Raises Exception if perf_event_open fails; errors from reading the id
    file or from the SET_FILTER ioctl propagate unchanged.
    """
    def __init__(self, group, name, event_set, tracepoint, filter = None):
        self.name = name
        attr = perf_event_attr()
        attr.type = PERF_TYPE_TRACEPOINT
        attr.size = ctypes.sizeof(attr)
        id_path = os.path.join(PATH_DEBUGFS_TRACING, 'events', event_set,
                               tracepoint, 'id')
        # Close the id file immediately; one is opened per event per CPU.
        with open(id_path) as id_file:
            attr.config = int(id_file.read())
        # The old code also assigned attr.sample_period here.  That name is
        # not in perf_event_attr._fields_, so it only created a plain Python
        # attribute and never reached the structure; it has been dropped.
        attr.read_format = PERF_FORMAT_GROUP
        # The first event of a group is opened with group fd -1; later
        # events attach to the first event's descriptor.
        group_leader = -1
        if group.events:
            group_leader = group.events[0].fd
        fd = _perf_event_open(attr, -1, group.cpu, group_leader, 0)
        if fd == -1:
            err = get_errno()[0]
            raise Exception('perf_event_open failed, errno = ' + str(err))
        if filter:
            try:
                fcntl.ioctl(fd, IOCTL_NUMBERS['SET_FILTER'], filter)
            except Exception:
                # This event is not registered anywhere yet, so nothing
                # else could close the descriptor.
                os.close(fd)
                raise
        self.fd = fd
    def enable(self):
        """Start counting (ENABLE ioctl)."""
        fcntl.ioctl(self.fd, IOCTL_NUMBERS['ENABLE'], 0)
    def disable(self):
        """Stop counting (DISABLE ioctl)."""
        fcntl.ioctl(self.fd, IOCTL_NUMBERS['DISABLE'], 0)
    def reset(self):
        """Issue the RESET ioctl on the event."""
        fcntl.ioctl(self.fd, IOCTL_NUMBERS['RESET'], 0)

class TracepointProvider(object):
    """Statistics provider that counts kvm tracepoints through perf events.

    Each directory under PATH_DEBUGFS_TRACING/events/kvm is one field.
    Tracepoints listed in the module-level filters dict additionally get one
    '<tracepoint>(<name>)' field per entry of their value table.  One Group
    of events is opened per online CPU.
    """
    def __init__(self):
        path = os.path.join(PATH_DEBUGFS_TRACING, 'events', 'kvm')
        fields = walkdir(path)[1]
        extra = []
        for f in fields:
            if f in filters:
                values = filters[f][1]
                for name in values:
                    extra.append(f + '(' + name + ')')
        fields += extra
        self._setup(fields)
        self.select(fields)
    def fields(self):
        """Return every field name this provider was set up with."""
        return self._fields

    def _online_cpus(self):
        """Return the numbers of all CPUs that are not marked offline.

        A CPU without an 'online' file is treated as online.
        """
        cpus = []
        pattern = re.compile(r'cpu([0-9]+)')
        basedir = '/sys/devices/system/cpu'
        for entry in os.listdir(basedir):
            match = pattern.match(entry)
            if not match:
                continue
            path = os.path.join(basedir, entry, 'online')
            if os.path.exists(path):
                with open(path) as online_file:
                    if online_file.read().strip() != '1':
                        continue
            cpus.append(int(match.group(1)))
        return cpus

    def _setup(self, _fields):
        """Open one event per field on every online CPU.

        The open-file limit is raised to 1000 descriptors per CPU first.
        Fields of the form '<tracepoint>(<name>)' are opened on
        <tracepoint> with a '<subfield>==<value>' filter taken from the
        filters dict.
        """
        self._fields = _fields
        cpus = self._online_cpus()
        nfiles = len(cpus) * 1000
        resource.setrlimit(resource.RLIMIT_NOFILE, (nfiles, nfiles))
        sub_event = re.compile(r'(.*)\((.*)\)')
        self.group_leaders = []
        for cpu in cpus:
            group = Group(cpu)
            for name in _fields:
                tracepoint = name
                event_filter = None
                m = sub_event.match(name)
                if m:
                    tracepoint, sub = m.groups()
                    subfield, values = filters[tracepoint]
                    # The filter string carries its own NUL terminator.
                    event_filter = '%s==%d\0' % (subfield, values[sub])
                group.add_event(name, event_set = 'kvm',
                                tracepoint = tracepoint,
                                filter = event_filter)
            self.group_leaders.append(group)
    def select(self, fields):
        """Reset and enable the events named in fields; disable the rest."""
        wanted = set(fields)
        for group in self.group_leaders:
            for event in group.events:
                if event.name in wanted:
                    event.reset()
                    event.enable()
                else:
                    event.disable()
    def read(self):
        """Return the per-field counter values summed over all CPUs."""
        ret = defaultdict(int)
        for group in self.group_leaders:
            for name, val in group.read().items():
                ret[name] += val
        return ret

class Stats:
    """Collect values from several providers and track per-call deltas.

    providers is a list of objects offering fields(), select(fields) and
    read(); fields is an optional regular expression (matched with
    re.match) limiting which field names are selected.
    """
    def __init__(self, providers, fields = None):
        self.providers = providers
        self.fields_filter = fields
        self._update()
    def _update(self):
        """Re-apply the field filter and reset all stored values."""
        def wanted(key):
            if not self.fields_filter:
                return True
            return re.match(self.fields_filter, key) is not None
        self.values = dict()
        # Use the providers given to the constructor rather than a
        # module-level global that merely happens to share the name.
        for d in self.providers:
            provider_fields = [key for key in d.fields() if wanted(key)]
            for key in provider_fields:
                self.values[key] = None
            d.select(provider_fields)
    def set_fields_filter(self, fields_filter):
        """Replace the field filter regex and reselect provider fields."""
        self.fields_filter = fields_filter
        self._update()
    def get(self):
        """Read all providers and return {field: (value, delta)}.

        delta is the change since the previous get(); it is None for a
        field read for the first time after the filter was (re)applied.
        The returned dict is the internal one and is updated in place by
        later calls.
        """
        for d in self.providers:
            new = d.read()
            for key in d.fields():
                oldval = self.values.get(key, (0, 0))
                newval = new[key]
                newdelta = None
                if oldval is not None:
                    newdelta = newval - oldval[0]
                self.values[key] = (newval, newdelta)
        return self.values

# Refuse to start when the debugfs locations this tool reads are missing;
# each check prints a hint to stderr and exits with status 1.
if not os.path.exists('/sys/kernel/debug'):
    # Terminate the message with a newline like the other hints, so the
    # shell prompt does not end up on the same line.
    sys.stderr.write('Please enable CONFIG_DEBUG_FS in your kernel.\n')
    sys.exit(1)
if not os.path.exists(PATH_DEBUGFS_KVM):
    sys.stderr.write("Please make sure, that debugfs is mounted and "
                     "readable by the current user:\n"
                     "('mount -t debugfs debugfs /sys/kernel/debug')\n"
                     "Also ensure, that the kvm modules are loaded.\n")
    sys.exit(1)
if not os.path.exists(PATH_DEBUGFS_TRACING):
    sys.stderr.write("Please make {0} readable by the current user.\n"
                     .format(PATH_DEBUGFS_TRACING))
    sys.exit(1)

# Column widths (in characters) used by tui() to place the event name and
# the 'Total' column.
LABEL_WIDTH = 40
NUMBER_WIDTH = 10

def tui(screen, stats):
    """Run the interactive curses display until the user quits.

    screen is the curses window to draw on and stats the Stats object to
    poll.  Keys: 'x' toggles the drilldown view (only when no field filter
    was given), 'q' or a keyboard interrupt ends the loop.
    """
    curses.use_default_colors()
    curses.noecho()
    drilldown = False
    fields_filter = stats.fields_filter

    def update_drilldown():
        # A filter given by the user is never overridden.
        if fields_filter:
            return
        if drilldown:
            stats.set_fields_filter(None)
        else:
            # Hide the '<tracepoint>(<name>)' sub-events.
            stats.set_fields_filter(r'^[^\(]*$')

    update_drilldown()

    def refresh(sleeptime):
        total_col = 1 + LABEL_WIDTH
        current_col = total_col + NUMBER_WIDTH
        screen.erase()
        screen.addstr(0, 0, 'kvm statistics')
        screen.addstr(2, 1, 'Event')
        screen.addstr(2, current_col - len('Total'), 'Total')
        screen.addstr(2, current_col + 8 - len('Current'), 'Current')
        s = stats.get()

        def sortkey(x):
            # Largest delta first, ties broken by the largest total.
            entry = s[x]
            if entry[1]:
                return (-entry[1], -entry[0])
            return (0, -entry[0])

        row = 3
        for key in sorted(s.keys(), key = sortkey):
            if row >= screen.getmaxyx()[0]:
                break
            values = s[key]
            # Everything after the first all-zero entry is zero as well.
            if not (values[0] or values[1]):
                break
            screen.addstr(row, 1, key)
            screen.addstr(row, total_col, '%10d' % (values[0],))
            if values[1] is not None:
                screen.addstr(row, current_col,
                              '%8d' % (values[1] / sleeptime,))
            row += 1
        screen.refresh()

    # The first refresh comes after a short delay, later ones every 3s.
    sleeptime = 0.25
    while True:
        refresh(sleeptime)
        curses.halfdelay(int(sleeptime * 10))
        sleeptime = 3
        try:
            c = screen.getkey()
            if c == 'x':
                drilldown = not drilldown
                update_drilldown()
            elif c == 'q':
                break
        except KeyboardInterrupt:
            break
        except curses.error:
            # No key within the halfdelay timeout: just redraw.
            continue

def batch(stats):
    s = stats.get()
    time.sleep(1)
    s = stats.get()
    for key in sorted(s.keys()):
        values = s[key]
        print '%-22s%10d%10d' % (key, values[0], values[1])

def log(stats):
    keys = sorted(stats.get().iterkeys())
    def banner():
        for k in keys:
            print '%10s' % k[0:9],
        print
    def statline():
        s = stats.get()
        for k in keys:
            print ' %9d' % s[k][1],
        print
    line = 0
    banner_repeat = 20
    while True:
        time.sleep(1)
        if line % banner_repeat == 0:
            banner()
        statline()
        line += 1

# Command-line interface.  After parsing, 'options' is rebound from the
# parser to the parsed option values.
options = optparse.OptionParser()
options.add_option('-1', '--once', '--batch',
                   action = 'store_true', default = False, dest = 'once',
                   help = 'run in batch mode for one second')
options.add_option('-l', '--log',
                   action = 'store_true', default = False, dest = 'log',
                   help = 'run in logging mode (like vmstat)')
options.add_option('-t', '--tracepoints',
                   action = 'store_true', default = False,
                   dest = 'tracepoints',
                   help = 'retrieve statistics from tracepoints')
options.add_option('-d', '--debugfs',
                   action = 'store_true', default = False, dest = 'debugfs',
                   help = 'retrieve statistics from debugfs')
options.add_option('-f', '--fields',
                   action = 'store', default = None, dest = 'fields',
                   help = 'fields to display (regex)')
(options, args) = options.parse_args(sys.argv)

# Build the provider list requested on the command line.
providers = []
if options.tracepoints:
    providers.append(TracepointProvider())
if options.debugfs:
    providers.append(DebugfsProvider())

if not providers:
    # Nothing requested explicitly: prefer tracepoints and fall back to
    # debugfs if they cannot be set up.  Catch Exception rather than
    # everything, so Ctrl-C and sys.exit() are not turned into a fallback.
    try:
        providers = [TracepointProvider()]
    except Exception:
        providers = [DebugfsProvider()]

stats = Stats(providers, fields = options.fields)

# Pick the output mode: continuous log, interactive curses UI (default), or
# a single batch report.
if options.log:
    log(stats)
elif not options.once:
    curses.wrapper(tui, stats)
else:
    batch(stats)