#!/usr/bin/env python3
# -*- Mode: Python -*-
#------------------------------------------------------------------------------#
#  DFTB+: general package for performing fast atomistic simulations            #
#  Copyright (C) 2006 - 2023  DFTB+ developers group                           #
#                                                                              #
#  See the LICENSE file for terms of usage and distribution.                   #
#------------------------------------------------------------------------------#
#
############################################################################
#
#  tagdiff -- small utility to compare numerical results of calculations
#
############################################################################
#
#  (The tagged data is assumed to be in the format produced by the taggedout
#  module of the DFTB project. See the corresponding source code for the
#  details. For the format of the config file, see the provided commented
#  sample config file.)
#
############################################################################
from __future__ import print_function
import sys
import os.path

# Oldest interpreter accepted by the script, in sys.hexversion encoding
REQUIRED_VERSION = 0x2060000
if not (sys.hexversion >= REQUIRED_VERSION):
    # Bail out early with a readable message instead of a later syntax error
    sys.stderr.write("This script requires Python 2.6 or higher\n")
    sys.exit(1)

import re
import gzip
from optparse import OptionParser
from tagreader import *
from uncommlines import UncommLines


# Version of the script, reported by the --version option
VERSION = 0.2

# Description shown in the command line help
DESCRIPTION = (
    "Compares two tagged output files according the tolerances\n"
    "in config file(s).\n"
)

# Possible outcomes of the comparison of a single tagged quantity
RES_OK, RES_SKIPPED, RES_FAILED, RES_ERROR = range(4)

# Exit codes of the script
EXIT_OK, EXIT_ERROR = 0, 1

# Config file looked up in the directory of the script if none is specified
DEFAULT_CONFIG = "tagdiff.conf"

############################################################################
# Exceptions
############################################################################

class DiffError(Exception):
    """Raised if building the difference fails for some reason."""


############################################################################
# Difference builder functors
############################################################################

class Diff(object):
    """Base class of the difference builder functors.

    An instance is called with the original and the new entry and returns
    their difference. This general variant builds no difference at all and
    always returns None.
    """

    def __str__(self):
        # Name of the difference method
        return "general"

    def __call__(self, orig, new):
        return None



class DiffElement(Diff):
    """Calculates the difference as the maximal difference of the elements"""

    def __call__(self, orig, new):
        """Returns the largest absolute elementwise difference.

        orig -- original entry, its elements are read from orig.value
        new  -- new entry, its elements are read from new.value

        Raises DiffError if the two entries do not contain the same number
        of elements or if the difference can not be built for any other
        reason (e.g. no elements at all or non-numerical elements).
        """
        try:
            origvals = orig.value
            newvals = new.value
            # Pairing the elements stops silently at the end of the shorter
            # sequence, so a length mismatch must be detected explicitly.
            if len(origvals) != len(newvals):
                raise DiffError("Invalid nr. of elements")
            diff = max(abs(x - y) for x, y in zip(origvals, newvals))
        except DiffError:
            raise
        except Exception as ee:
            raise DiffError("Exception (%s) while building difference"
                            % str(ee))
        return diff

    def __str__(self):
        return "element"



class DiffVector(Diff):
    """Calculates the difference as the max. difference in vector norm."""

    def __init__(self, nElement):
        """nElement -- number of consecutive elements forming one vector. If
                       -1, the first item of the shape of the original entry
                       is used instead.
        """
        self.__nElement = nElement


    def __call__(self, orig, new):
        """Returns the largest Euclidean norm among the difference vectors.

        orig -- original entry (orig.value and, if the vector length is -1,
                orig.shape are read)
        new  -- new entry (new.value is read)

        The elements of both entries are split into consecutive chunks of
        the vector length; for each chunk the norm of the difference of the
        original and the new chunk is built.

        Raises DiffError if the number of elements differs between the
        entries, is not a multiple of a positive vector length, or if the
        difference can not be built for any other reason.
        """
        if self.__nElement == -1:
            nElement = orig.shape[0]
        else:
            nElement = self.__nElement

        nValue = len(orig.value)
        # A vector length below one would cause a division by zero in the
        # modulo or an empty chunk loop, so it is rejected here as well.
        if (nElement < 1 or nValue != len(new.value)
            or nValue % nElement != 0):
            raise DiffError("Invalid nr. of elements")

        try:
            diff2 = []
            for ii in range(0, nValue, nElement):
                origvals = orig.value[ii : ii + nElement]
                newvals = new.value[ii : ii + nElement]
                diff2.append(sum(abs(x - y)**2
                                 for x, y in zip(origvals, newvals)))
            # max() fails for entries without elements; reported as DiffError
            maxdiff = max(diff2)**0.5
        except Exception as ee:
            raise DiffError("Exception (%s) while building difference"
                            % str(ee))

        return maxdiff


    def __str__(self):
        return "vector:%d" % self.__nElement



############################################################################
# Tolerance related objects
############################################################################

class ToleranceEntry(object):
    """Contains tolerance related data: the pattern selecting the tagged
    quantities, the tolerance value, the comparison functor and the keep flag.
    """

    # Valid comparison functors with converters to convert strings to the
    # type of their initialization arguments
    __compFuncs = { "element": (DiffElement, ()),
                    "vector":    (DiffVector, (IntConverter(nolist=True),))
    }

    def __init__(self, pattern, value, compFuncName, compFuncArgs, keep):
        """pattern        -- regular expression (as string)
             value        -- tolerance for quantities matching pattern
                             (as string, integer or floating point number)
             compFuncName -- comparison function's name
             compFuncArgs -- arguments to the comparison function (sequence
                             of strings)
             keep         -- "keep", if the processed entry should be kept
                             after processing, any other string otherwise

        Raises InvalidEntry if the pattern, the tolerance value, the name of
        the comparison function or its arguments are invalid.
        """
        # Store representation
        field = ":".join([ compFuncName, ] + list(compFuncArgs))
        self.__str = " @ ".join([ pattern, value, field, keep ])

        # Convert regular expression
        try:
            self.__pattern = re.compile(pattern)
        except re.error:
            raise InvalidEntry(msg="Invalid regular expression")

        # Convert value: integer if possible, floating point number otherwise
        try:
            self.__value = int(value)
        except ValueError:
            try:
                self.__value = float(value)
            except ValueError:
                raise InvalidEntry(msg="Invalid tolerance value")

        # Convert comparison method. The reason of a failure is collected in
        # msg, which stays None on success.
        msg = None
        if compFuncName not in self.__compFuncs:
            msg = "Invalid function name"
        else:
            (compFunc, argConverters) = self.__compFuncs[compFuncName]
            if len(argConverters) != len(compFuncArgs):
                msg = "Invalid number of arguments"
            else:
                try:
                    args = [ convert(arg) for convert, arg
                             in zip(argConverters, compFuncArgs) ]
                    self.__compFunc = compFunc(*args)
                except ConversionError as ee:
                    # The name bound by "as" is gone after the except clause
                    # on Python 3, so the text is saved under another name.
                    msg = str(ee)
        if msg is not None:
            raise InvalidEntry(msg="Invalid comparison function '%s' (%s)"
                                                 % (compFuncName, msg))

        # Convert keep-flag
        self.__keep = (keep == "keep")


    def getPattern(self):
        """Returns the compiled regular expression."""
        return self.__pattern
    pattern = property(getPattern, None, None, "pattern")


    def getValue(self):
        """Returns the tolerance value as number."""
        return self.__value
    value = property(getValue, None, None, "value")


    def getCompFunc(self):
        """Returns the initialized comparison functor."""
        return self.__compFunc
    compFunc = property(getCompFunc, None, None, "comparison function")


    def getKeep(self):
        """Returns True, if the keep flag was set, False otherwise."""
        return self.__keep
    keep = property(getKeep, None, None, "keep flag")

    def __str__(self):
        return self.__str



############################################################################
# Parsing
############################################################################

class ConfigParser(object):
    """Reads a tagdiff configuration file and delivers its content as
    ToleranceEntry instances.
    """

    # Valid comparison functors with argument type lists for initialization
    __comparison = {
        "element": (DiffElement, ()),
        "vector": (DiffVector, (IntConverter(nolist=True),)),
    }

    # Comparison functor name and arguments used if a line specifies none
    __defCompFunc = "element"
    __defCompFuncArgs = ()

    # Flag used if a line specifies none
    __defFlag = "nokeep"


    def __init__(self, file):
        """file -- file like (opened) object containing the configuration
                   file. It must stay open until all entries have been
                   delivered by the parser.
        """
        self.__file = file


    def iterateEntries(self):
        """Generator delivering the entries of the config file one by one.

        Every line has the fields "pattern @ value @ function:args @ flag",
        where the last two fields are optional. Raises InvalidEntry carrying
        the line numbers, if a line can not be converted into an entry.
        """
        for (line, iLine) in UncommLines(self.__file, returnLineNr=True):
            fields = [ field.strip() for field in line.split("@") ]
            nField = len(fields)
            if nField < 2:
                raise InvalidEntry(iLine+1, iLine+2, "Not enough fields")

            pattern, value = fields[0], fields[1]

            # Optional third field: comparison function with its arguments
            if nField >= 3 and fields[2] != "":
                tokens = [ token.strip() for token in fields[2].split(":") ]
                compFunc = tokens[0]
                compFuncArgs = tuple(tokens[1:])
            else:
                compFunc = self.__defCompFunc
                compFuncArgs = self.__defCompFuncArgs

            # Optional fourth field: keep flag
            if nField >= 4 and fields[3] != "":
                flag = fields[3]
            else:
                flag = self.__defFlag

            try:
                entry = ToleranceEntry(pattern, value, compFunc,
                                       compFuncArgs, flag)
            except InvalidEntry as ee:
                # Pass the failure on with the line numbers attached
                raise InvalidEntry(iLine+1, iLine+2, ee.msg)
            yield entry

    entries = property(iterateEntries, doc="Sequence of extracted entries.")


############################################################################
# Input/Output
############################################################################
# Text printed for each possible outcome of a comparison
resultStr = {
    RES_OK: "OK",
    RES_SKIPPED: "Skipped",
    RES_FAILED: "Failed",
    RES_ERROR: "Error",
}

def printResult(name, method, msg, result):
    """Prints the result of a comparison as one line on standard output
    name   -- name of the data
    method -- comparison method
    msg    -- message to print
    result -- result of the comparison (one of the keys of resultStr)
    """
    outcome = resultStr[result]
    # Left aligned, fixed width columns
    columns = "%-20s %-20s %-27s" % (name, method, msg)
    sys.stdout.write("%s %-10s\n" % (columns, outcome))



def printError(message):
    """Writes message with the prefix "ERROR::" as a line to standard error"""
    line = "ERROR::%s\n" % message
    sys.stderr.write(line)


def zOpen(filename, mode):
    """Opens a file according to its extension with different methods

    filename -- name of the file. If it ends with ".gz", the file is opened
                via the gzip module, otherwise with the built-in open().
    mode     -- mode of opening as for the built-in open()

    Returns the opened file object.
    """
    if len(filename) > 3 and filename.endswith(".gz"):
        # On Python 3 gzip.open() handles a mode without "b" or "t" as
        # binary, while open() handles it as text. Text mode is requested
        # explicitly, so that both branches deliver the same kind of data.
        if sys.version_info[0] >= 3 and "b" not in mode and "t" not in mode:
            mode += "t"
        return gzip.open(filename, mode)
    return open(filename, mode)



############################################################################
# Option processing
############################################################################

def parseOptions():
    """Processes the command line of the script.

    Returns the tuple (options, args) as delivered by the option parser.
    If less than two positional arguments are present, the help is printed
    and the script exits with EXIT_ERROR.
    """
    parser = OptionParser(
        usage="usage: %prog [ options ] orig new",
        description=DESCRIPTION,
        version="%%prog %s" % VERSION)
    parser.add_option(
        "-c", "--config", action="append", dest="configfile",
        help=("config file to use (multiple config files can be "
              "specified by using this option multiple times)"))
    parser.add_option(
        "-v", "--verbose", action="store_true", dest="verbose",
        default=False, help="verbose mode")

    (options, args) = parser.parse_args()
    # Both the original and the new tagged file must be specified
    if len(args) < 2:
        parser.print_help()
        sys.exit(EXIT_ERROR)
    return (options, args)



############################################################################
# Main program
############################################################################

def _readConfigEntries(fp):
    """Returns the list of tolerance entries found in the open config file"""
    return list(ConfigParser(fp).entries)


def _readTaggedCollection(fp):
    """Returns the collection of the entries found in the open tagged file"""
    return TaggedCollection(ResultParser(fp).entries)


def _readFile(filename, reader):
    """Opens a file, applies reader to it and closes it again.

    filename -- name of the file to process (opened via zOpen)
    reader   -- callable, which turns the open file object into the result

    Returns the result of reader. If the file contains an invalid entry or
    an input/output error occurs, an error message is printed and None is
    returned instead.
    """
    try:
        f = zOpen(filename, "r")
        try:
            return reader(f)
        finally:
            # Executed for failing readers as well, so that the file is
            # closed on every path once it has been opened.
            f.close()
    except InvalidEntry as ee:
        printError("Invalid entry (%s) in file '%s' between lines %d and %d"
                   % (ee.msg, filename, ee.start, ee.end))
    except IOError:
        printError("Input/output error for file '%s'" % (filename,))
    return None


def main():
    """Compares the two tagged files specified on the command line.

    The rules of the config file(s) are processed in the order of their
    appearance; every quantity of the original file matching the pattern of
    a rule is compared with the quantity of the same name in the new file
    and the outcome is printed.

    Returns EXIT_ERROR, if any of the files could not be read or parsed,
    and EXIT_OK otherwise. Failed comparisons are only reported in the
    output, they do not influence the returned value.
    """

    #
    # Parse options
    #
    options, args = parseOptions()
    oldFile, newFile = args[:2]
    if options.configfile:
        confFiles = options.configfile[:]
    else:
        # Fall back to the config file in the directory of the script
        scriptDir = os.path.dirname(sys.argv[0])
        confFiles = [os.path.join(scriptDir, DEFAULT_CONFIG)]

    configEntries = []
    for confFile in confFiles:
        if options.verbose:
            print("# Reading config file `%s'" % confFile)
        entries = _readFile(confFile, _readConfigEntries)
        # An empty list is a valid result, only None signals a failure
        if entries is None:
            return EXIT_ERROR
        configEntries += entries

    #
    # Read and parse files
    #
    if options.verbose:
        print("# Reading old tagged file `%s'" % oldFile)
    old = _readFile(oldFile, _readTaggedCollection)
    if old is None:
        return EXIT_ERROR

    if options.verbose:
        print("# Reading new tagged file `%s'" % newFile)
    new = _readFile(newFile, _readTaggedCollection)
    if new is None:
        return EXIT_ERROR

    #
    # Compare entries
    #
    for configEntry in configEntries:
        if options.verbose:
            print(" # Processing rule:     `%s'" % configEntry)

        compFunc = configEntry.compFunc
        for oldEntry in old.getMatchingEntries(configEntry.pattern):
            name = oldEntry.name
            newEntry = new.getEntry(name)

            # Without the keep flag the quantity is removed from both
            # collections, so that later rules do not process it again.
            if not configEntry.keep:
                old.delEntry(name)
                new.delEntry(name)

            if not newEntry:
                printResult(name, compFunc, "Not found in new", RES_SKIPPED)
                continue
            if not oldEntry.isComparable(newEntry):
                printResult(name, compFunc, "Mismatching data", RES_ERROR)
                continue

            try:
                result = compFunc(oldEntry, newEntry)
            except DiffError as ee:
                printResult(name, compFunc,
                            "Difference building error (%s)" % ee, RES_ERROR)
                continue

            # The difference is printed for passed and failed tests alike
            if result <= configEntry.value:
                res = RES_OK
            else:
                res = RES_FAILED
            printResult(name, compFunc, "%-20s" % (str(result)), res)

    return EXIT_OK



############################################################################
# Start
############################################################################

if __name__ == "__main__":
    # Pass the status returned by main() on as exit code of the script
    sys.exit(main())

## Test
#
#confFile = "tagdiff.conf"
#oldFile = "orig.tgo"
#newFile = "new.tgo"
#sys.argv = ["alma", confFile, oldFile, newFile ]
#main()
