#!/usr/bin/python
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#

#
# a tool for post-processing Valgrind output from the unit tests
# Use:
#    1) configure the build to use valgrind and output xml
#       $ cmake .. -DUSE_VALGRIND=Yes -DVALGRIND_XML=Yes
#    2) build and run the unit tests
#       $ make && make test
#    3) run grinder from your build directory.  It will look for valgrind xml
#       files named "valgrind-*.xml" in the current directory and all
#       subdirectories and process them. Output is sent to stdout
#       $ ../bin/grinder
#
# Note: be sure to clean the build directory before running the unit tests
# to remove old valgrind-*.xml files
#

import logging
import os
import re
import sys
import xml.etree.ElementTree as ET
from xml.etree.ElementTree import ParseError


class Frame:
    """
    One entry of a valgrind call stack.

    The text of each child element named in FIELDS is copied out of the
    given <frame> element; a child that is absent is recorded as "<none>".
    """
    FIELDS = ["fn", "dir", "file", "line"]

    def __init__(self, frame):
        self._fields = {}
        for name in self.FIELDS:
            element = frame.find(name)
            if element is None:
                self._fields[name] = "<none>"
            else:
                self._fields[name] = element.text

    def __str__(self):
        # rendered as "(function) directory/file:line"
        values = self._fields
        return "(%s) %s/%s:%s" % (values['fn'],
                                  values['dir'],
                                  values['file'],
                                  values['line'])

    def __hash__(self):
        # two frames hash alike exactly when they render alike
        return hash(str(self))


class ErrorBase:
    """
    State and ordering common to every valgrind error.

    An instance carries the error ``kind`` string and a ``count`` of how
    many occurrences have been merged into it.  All rich comparisons look
    at ``count`` only, while the hash is derived from ``kind`` only.
    """
    def __init__(self, kind):
        self.count = 1
        self.kind = kind

    def merge(self, other):
        """Add the occurrence count of *other* to this error."""
        self.count = self.count + other.count

    def __str__(self):
        return "kind = %s  (count=%d)" % (self.kind, self.count)

    def __hash__(self):
        return hash(self.kind)

    # Comparisons rank errors by how often they occurred.

    def __eq__(self, other):
        return self.count == other.count

    def __lt__(self, other):
        return self.count < other.count

    def __le__(self, other):
        return self.count <= other.count

    def __gt__(self, other):
        return self.count > other.count

    def __ge__(self, other):
        return self.count >= other.count


class GeneralError(ErrorBase):
    """
    For simple single stack errors

    Built from an <error> element: reads the <kind> text, the optional
    <what> description and the frames of the first <stack> child.
    """
    def __init__(self, error_xml):
        kind = error_xml.find("kind").text
        super(GeneralError, self).__init__(kind)
        w = error_xml.find("what")
        self._what = w.text if w is not None else "<none>"

        # stack: a missing <stack> is tolerated just like a missing <what>
        # and simply yields an empty frame list
        s = error_xml.find("stack")
        frames = s.findall("frame") if s is not None else []
        self._stack = [Frame(frame) for frame in frames]

    def __hash__(self):
        # Hash the frames as an ordered tuple: summing the frame hashes
        # would give the same value to stacks that hold the same frames in
        # a different order, and such errors would be coalesced by mistake.
        return hash((super(GeneralError, self).__hash__(),
                     tuple(hash(f) for f in self._stack)))

    def __str__(self):
        lines = [super(GeneralError, self).__str__()]
        if self._what:
            lines.append(self._what)
        lines.append("Stack:")
        lines.extend("  %s" % str(frame) for frame in self._stack)
        return "\n".join(lines)


class LeakError(ErrorBase):
    """
    A leak record: an error whose kind starts with "Leak_"

    Besides the stack it tracks the leaked byte and block totals, which are
    summed up when duplicate records are merged.
    """
    def __init__(self, error_xml):
        kind = error_xml.find("kind").text
        assert kind.startswith("Leak_")
        super(LeakError, self).__init__(kind)

        # xwhat:
        #    leakedbytes
        #    leakedblocks
        # both are optional and default to zero
        self._leaked_bytes = 0
        self._leaked_blocks = 0
        lb = error_xml.find("xwhat/leakedbytes")
        if lb is not None:
            self._leaked_bytes = int(lb.text)
        lb = error_xml.find("xwhat/leakedblocks")
        if lb is not None:
            self._leaked_blocks = int(lb.text)

        # stack: a missing <stack> yields an empty frame list instead of
        # failing on None.findall()
        s = error_xml.find("stack")
        frames = s.findall("frame") if s is not None else []
        self._stack = [Frame(frame) for frame in frames]

    def merge(self, other):
        """Add the count, leaked bytes and leaked blocks of *other*."""
        super(LeakError, self).merge(other)
        self._leaked_bytes += other._leaked_bytes
        self._leaked_blocks += other._leaked_blocks

    def __hash__(self):
        # Hash the frames as an ordered tuple: summing the frame hashes
        # would give the same value to stacks that hold the same frames in
        # a different order, and such leaks would be coalesced by mistake.
        return hash((super(LeakError, self).__hash__(),
                     tuple(hash(f) for f in self._stack)))

    def __str__(self):
        lines = [super(LeakError, self).__str__(),
                 "Leaked Bytes = %d Blocks = %d" % (self._leaked_bytes,
                                                    self._leaked_blocks),
                 "Stack:"]
        lines.extend("  %s" % str(frame) for frame in self._stack)
        return "\n".join(lines)


class InvalidMemError(ErrorBase):
    """
    An error with one primary stack plus optional auxiliary stacks

    The first <stack> child is the primary stack; every later <stack> is
    kept as an auxiliary stack and every <auxwhat> text is kept as an
    auxiliary description.
    """
    def __init__(self, error_xml):
        kind = error_xml.find("kind").text
        super(InvalidMemError, self).__init__(kind)
        # expect
        #  what
        #  stack  (invalid access)
        #  followed by zero or more:
        #      aux what  (aux stack description)
        #      aux stack  (where alloced, freed)
        self._what = "<none>"
        self._auxwhat = list()
        self._aux_stacks = list()
        primary = None
        for child in error_xml:
            if child.tag == "what":
                self._what = child.text
            elif child.tag == "auxwhat":
                self._auxwhat.append(child.text)
            elif child.tag == "stack":
                stack = [Frame(frame) for frame in child.findall("frame")]
                if primary is None:
                    primary = stack
                else:
                    self._aux_stacks.append(stack)
        # An error without any <stack> gets an empty primary stack so that
        # __hash__ and __str__ never have to iterate over None.
        self._stack = primary if primary is not None else []

    def __hash__(self):
        # for now don't include what/auxwhat as it may
        # be different for the same codepath
        #
        # The stacks are hashed as ordered tuples, with the primary stack
        # kept apart from the auxiliary ones, so errors that merely share
        # the same set of frames do not end up with the same hash.
        primary = tuple(hash(f) for f in self._stack)
        aux = tuple(tuple(hash(f) for f in s) for s in self._aux_stacks)
        return hash((super(InvalidMemError, self).__hash__(), primary, aux))

    def __str__(self):
        lines = [super(InvalidMemError, self).__str__(),
                 "%s" % self._what,
                 "Stack:"]
        lines.extend("  %s" % str(frame) for frame in self._stack)

        # descriptions and auxiliary stacks are paired by position; zip()
        # stops at the shorter of the two lists
        for what, stack in zip(self._auxwhat, self._aux_stacks):
            lines.append("%s:" % what)
            lines.extend("  %s" % str(frame) for frame in stack)
        return "\n".join(lines)


class SignalError(ErrorBase):
    """
    A <fatal_signal> record, always reported with kind "FatalSignal"
    """
    def __init__(self, error_xml):
        super(SignalError, self).__init__("FatalSignal")
        # expects:
        #  signo
        #  signame
        #  stack
        # any of them may be absent
        sn = error_xml.find("signo")
        self._signo = sn.text if sn is not None else "<none>"
        sn = error_xml.find("signame")
        self._signame = sn.text if sn is not None else "<none>"

        # a missing <stack> yields an empty frame list instead of failing
        # on None.findall()
        s = error_xml.find("stack")
        frames = s.findall("frame") if s is not None else []
        self._stack = [Frame(frame) for frame in frames]

    def __hash__(self):
        # The signal number and name are part of the identity, together
        # with the frames as an ordered tuple: summing the individual
        # hashes would not tell apart stacks that hold the same frames in a
        # different order.
        return hash((super(SignalError, self).__hash__(),
                     self._signo,
                     self._signame,
                     tuple(hash(f) for f in self._stack)))

    def __str__(self):
        lines = [super(SignalError, self).__str__(),
                 "Signal %s (%s)" % (self._signo, self._signame),
                 "Stack:"]
        lines.extend("  %s" % str(frame) for frame in self._stack)
        return "\n".join(lines)


# Maps the text of an error's <kind> element to the class that represents it.
# A kind mapped to None is known but has no handler yet; parse_error() treats
# it the same way as a kind that is not listed at all.
_ERROR_CLASSES = {
    # invalid accesses
    'InvalidRead': InvalidMemError,
    'InvalidWrite': InvalidMemError,
    # leaks
    'Leak_DefinitelyLost': LeakError,
    'Leak_IndirectlyLost': LeakError,
    'Leak_PossiblyLost': LeakError,
    'Leak_StillReachable': LeakError,
    # single stack errors
    'UninitCondition': GeneralError,
    'SyscallParam': GeneralError,
    # handled like the invalid accesses
    'InvalidFree': InvalidMemError,
    'FishyValue': InvalidMemError,
    # TBD:
    'InvalidJump': None,
    'UninitValue': None,
}


def parse_error(error_xml):
    """
    Build the error object that matches the <kind> of the given element.

    The class is looked up in _ERROR_CLASSES.  An Exception is raised when
    the kind is not listed there or is listed without a class.
    """
    kind = error_xml.find("kind").text
    handler = _ERROR_CLASSES.get(kind)
    if not handler:
        raise Exception("Unsupported error type %s, please update grinder"
                        " to handle it" % kind)
    return handler(error_xml)


def parse_xml_file(filename, exe_name='qdrouterd'):
    """
    Parse out errors from a valgrind output xml file

    Returns a list holding one error object per <error> element and one
    SignalError per <fatal_signal> element.  The list is empty when the file
    cannot be parsed, is not protocol version 4, was not written by
    memcheck, or does not name an executable containing exe_name in
    args/argv/exe.
    """
    logging.debug("Parsing %s", filename)
    error_list = list()
    try:
        root = ET.parse(filename).getroot()
    except ParseError as exc:
        # "no element found" is logged at debug level only; every other
        # parse failure is worth a warning
        if "no element found" not in str(exc):
            logging.warning("Error parsing %s: %s - skipping",
                            filename, str(exc))
        else:
            logging.debug("No errors found in: %s - skipping",
                          filename)
        return error_list

    pv = root.find('protocolversion')
    if pv is None or pv.text != "4":
        # unsupported xml format version
        logging.warning("Unsupported format version for %s, skipping...",
                        filename)
        return error_list

    pt = root.find('protocoltool')
    if pt is None or pt.text != "memcheck":
        logging.warning("Not a memcheck file %s, skipping...",
                        filename)
        return error_list

    # A file that lacks the exe element (or has an empty one) cannot be
    # attributed to the target executable, so it is skipped like any other
    # foreign file instead of failing on None.text.
    exe = root.find('args/argv/exe')
    if exe is None or exe.text is None or exe_name not in exe.text:
        # not from the target executable, skip
        logging.debug("file %s is not generated from %s, skipping...",
                      filename, exe_name)
        return error_list

    for error in root.findall('error'):
        error_list.append(parse_error(error))

    # sigabort, etc classified as fatal_signal
    for signal in root.findall("fatal_signal"):
        error_list.append(SignalError(signal))
    return error_list


def main():
    """
    Find the valgrind xml files below the current directory, parse them and
    print the coalesced errors to stdout, most frequent first.

    Errors with the same hash are treated as duplicates and merged into one
    entry.
    """
    errors_map = dict()
    # Raw string so the escaped dot reaches the regex engine as written, and
    # anchored with \Z so that leftovers such as "valgrind-123.xml.bak" are
    # not processed as well.
    file_name = re.compile(r"valgrind-[0-9]+\.xml\Z")
    for dirpath, _dirnames, filenames in os.walk("."):
        for name in filenames:
            if not file_name.match(name):
                continue
            for e in parse_xml_file(os.path.join(dirpath, name)):
                h = hash(e)
                if h in errors_map:
                    # coalesce duplicate errors
                    errors_map[h].merge(e)
                else:
                    errors_map[h] = e

    # sort by # of occurrences
    error_list = sorted(errors_map.values(), reverse=True)

    if error_list:
        for e in error_list:
            print("\n-----")
            print("%s" % str(e))
        print("\n\n-----")
        print("----- %s total issues detected" % len(error_list))
        print("-----")
    else:
        print("No Valgrind errors found! Congratulations ;)")


if __name__ == "__main__":
    # run the report only when executed as a script; whatever main()
    # returns becomes the exit status
    exit_status = main()
    sys.exit(exit_status)
