#!/usr/bin/env python3
#------------------------------------------------------------------------------#
#  DFTB+: general package for performing fast atomistic simulations            #
#  Copyright (C) 2006 - 2023  DFTB+ developers group                           #
#                                                                              #
#  See the LICENSE file for terms of usage and distribution.                   #
#------------------------------------------------------------------------------#
#
"""
Calculates forces and lattice derivatives with finite differences.
"""

import argparse
from derivtools import readgen, writegen, cart2frac, frac2cart, stress2latderivs, AA__BOHR, exists, create_temporary_copy
import numpy as np
import re
import subprocess
import shutil

DESCRIPTION = """
Calculates the forces using the specified DFTB+ binary by finite differences
displacing the atoms along every axis. The geometry of the configuration must
be specified in a file called 'geo.gen.template'. The DFTB+ input file should
include the geometry from the file 'geo.gen'. The input file must specify the
option for writing an autotest file and if the reference should be calculated
then also the option for calculating the forces.
"""

# Matches the 'mermin_energy' entry of an autotest file and captures the first
# whitespace-delimited token that follows the third ':' of its header.
ENERGY_PATTERN = re.compile(r"mermin_energy[^:]*:[^:]*:[^:]*:\s*(?P<value>\S+)")

# Matches a 'forces' entry whose header ends in a two-integer shape ("n,m")
# and captures the run of whitespace-separated numbers that follows it.
# NOTE(review): unlike LATTICE_DERIV_PATTERN this pattern has no '^' anchor, so
# re.MULTILINE has no effect here and any tag merely containing 'forces' can
# match as well -- confirm that this is intended.
FORCES_PATTERN = re.compile(
    r"forces[^:]*:[^:]*:[^:]*:\d+,\d+\s*"
    r"(?P<values>(?:\s*[+-]?\d+(?:\.\d+(?:E[+-]?\d+)?)?)+)",
    re.MULTILINE)

# Matches a 'stress' entry starting at the beginning of a line and captures
# the run of whitespace-separated numbers that follows its "n,m" shape field.
LATTICE_DERIV_PATTERN = re.compile(
    r"^stress[^:]*:[^:]*:[^:]*:\d+,\d+\s*"
    r"(?P<values>(?:\s*[+-]?\d+(?:\.\d+(?:E[+-]?\d+)?)?)+)", re.MULTILINE)

# Name under which the autotest output of the reference calculation is stored.
REFERENCE_AUTOTEST = 'autotest0.tag'
# Charge file that is backed up at start and copied back before every run.
CHARGES_FILE = "charges.bin"

def main():
    """Runs the finite-difference workflow selected on the command line.

    The geometry is read from 'geo.gen.template'. Depending on the command
    line options a reference autotest file is calculated or read, forces
    and/or lattice derivatives are calculated by finite differences and the
    results are printed (together with the deviation from the reference, if
    one is available).

    If the charge file exists at start, a temporary copy of it is made and the
    charge file is restored from that copy before the routine returns, also
    when one of the calculations fails with an exception.
    """

    args = parse_arguments()
    directions = directions_from_args(args.directions)
    specienames, species, coords, origin, latvecs = readgen("geo.gen.template")
    atoms = atoms_from_args(args.atoms, coords.shape[0])
    if exists(CHARGES_FILE):
        charge_restart = create_temporary_copy(CHARGES_FILE)
    else:
        charge_restart = None

    try:
        calcforces = not args.skipforces
        # Lattice derivatives are only calculated if lattice vectors were read.
        calclatderivs = (latvecs is not None) and not args.skiplattice
        # parse_arguments() guarantees that --reference and --calc-reference
        # are never given together, so args.ref is None whenever calcref is
        # set.
        reffile = REFERENCE_AUTOTEST if args.calcref else args.ref
        # The displacement is given in Angstrom on the command line;
        # presumably AA__BOHR converts it to Bohr -- confirm in derivtools.
        disp = args.disp * AA__BOHR
        binary = args.binary
        print("BINARY:", binary)

        if args.calcref:
            calculate_reference(binary, reffile, coords, specienames, species,
                                origin, latvecs, charge_restart)

        if reffile is not None:
            forces0, latderivs0 = read_reference_results(
                reffile, calcforces, calclatderivs, latvecs)
        else:
            forces0 = latderivs0 = None

        # All calculations are done first, the results are printed afterwards
        # so that they are not interleaved with the output of the runs.
        if calcforces:
            forces = calculate_forces(binary, disp, coords, specienames,
                                      species, origin, latvecs,
                                      charge_restart, directions, atoms)

        if calclatderivs:
            latderivs = calculate_latderivs(binary, disp, coords, latvecs,
                                            specienames, species, origin,
                                            charge_restart, directions)

        if calcforces:
            print_forces(forces, forces0, directions, atoms)

        if calclatderivs:
            print_latderivs(latderivs, latderivs0, directions)
    finally:
        # Put the original charge file back even if a run failed, so that an
        # aborted invocation does not leave behind an overwritten file.
        if charge_restart:
            shutil.copy2(charge_restart.name, CHARGES_FILE)


def parse_arguments():
    """Parses the command line arguments.

    Returns:
        The argparse namespace with the attributes disp, ref, calcref,
        skiplattice, skipforces, directions, atoms and binary.

    Exits via parser.error() if a reference file and the calculation of the
    reference system are requested at the same time.
    """

    parser = argparse.ArgumentParser(description=DESCRIPTION)

    msg = "Specify the displacement of the atoms (unit: ANGSTROM)"
    parser.add_argument("-d", "--displacement", type=float, dest="disp",
                        default=1e-5, help=msg)

    msg = "Compare derivatives with those in reference autotest file"
    parser.add_argument("-r", "--reference", dest="ref", help=msg)

    msg = ("Calculate reference system (and compare derivatives with it),"
           " resulting autotest file will be saved as 'autotest0.tag'")
    parser.add_argument("-c", "--calc-reference", dest="calcref",
                        action="store_true", default=False, help=msg)

    msg = ("Skip the calculation of the lattice derivatives in the case of"
           " periodic systems")
    parser.add_argument("-L", "--skip-lattice", dest="skiplattice",
                        action="store_true", default=False, help=msg)

    msg = ("Skip the calculation of forces (useful when only lattice"
           " derivatives should be calculated)")
    parser.add_argument("-F", "--skip-forces", dest="skipforces",
                        action="store_true", default=False, help=msg)

    # The two literals used to be joined without any separator, which
    # rendered as "etc.)input order ..." in the help output.
    msg = ("Directions of calculation (arguments like 'xyz' 'xy', etc.),"
           " input order not maintained, default: xyz")
    parser.add_argument("--directions", dest="directions", default="xyz",
                        help=msg)

    msg = ("Atoms for calculation, starting with one and including both limits,"
           " use ':' for ranges, ',' for separation and '1:-x' for excluding"
           " the last x atoms. Output order will be sorted smallest to "
           "highest (arguments like '1,2' '1:5', '2,4:6,9' and '1,3:4,9:-4')")
    parser.add_argument("-a", "--atoms", dest="atoms", default=None,
                        help=msg)

    msg = "DFTB+ binary"
    parser.add_argument("binary", help=msg)

    args = parser.parse_args()
    if args.ref is not None and args.calcref:
        msg = ("Specifying a reference file and requesting a calculation of the"
               " reference system are mutually exclusive options")
        parser.error(msg)
    return args


def _read_three_column_block(pattern, txt, errormsg):
    """Returns the numbers captured by pattern in txt as an (n, 3) array.

    Raises:
        ValueError: With the message errormsg, if pattern is not found in txt.
    """
    match = pattern.search(txt)
    if not match:
        raise ValueError(errormsg)
    values = np.fromstring(match.group("values"), count=-1, dtype=float,
                           sep=" ")
    return values.reshape((-1, 3))


def read_reference_results(autotest0, calcforces, calclatderivs, latvecs):
    """Reads the reference forces and lattice derivatives from an autotest file.

    Args:
        autotest0: Name of the autotest file containing the reference data.
        calcforces: Whether the forces should be read.
        calclatderivs: Whether the lattice derivatives should be read.
        latvecs: Lattice vectors, passed on to stress2latderivs() together
            with the stress read from the file.

    Returns:
        Tuple (forces0, latderivs0). An entry is None if the corresponding
        quantity was not requested.

    Raises:
        ValueError: If a requested quantity is not found in the file.
    """

    forces0 = None
    latderivs0 = None
    # The context manager closes the file also if reading fails.
    with open(autotest0, "r") as fp:
        txt = fp.read()
    if calcforces:
        forces0 = _read_three_column_block(
            FORCES_PATTERN, txt, "No forces found in reference file!")
    if calclatderivs:
        # The file stores the stress; it is converted to lattice derivatives.
        stress0 = _read_three_column_block(
            LATTICE_DERIV_PATTERN, txt,
            "No lattice derivatives found in reference file!")
        latderivs0 = stress2latderivs(stress0, latvecs)
    return forces0, latderivs0


def calculate_reference(binary, reffile, coords, specienames, species, origin,
                        latvecs, charge_restart):
    """Runs the binary on the undisplaced geometry and stores its autotest file.

    The geometry is written to 'geo.gen', the charge file is refreshed from
    charge_restart (if given), the binary is executed and the resulting
    'autotest.tag' is moved to reffile.
    """

    geometry = (specienames, species, coords, origin, latvecs)
    writegen("geo.gen", geometry)
    if charge_restart:
        # Start from the saved charges rather than from those of an earlier run
        shutil.copy2(charge_restart.name, CHARGES_FILE)
    # NOTE(review): the return code of the binary is not checked here.
    command = [binary]
    subprocess.call(command)
    shutil.move("autotest.tag", reffile)


def calculate_forces(binary, disp, coords, specienames, species, origin,
                     latvecs, charge_restart, directions, atoms):
    """Calculates forces by central finite differences of the energy.

    For every selected atom and direction the coordinate is displaced by
    -disp and +disp, the binary is run for each displaced geometry and the
    energy is read from 'autotest.tag'.

    Args:
        binary: Binary to execute for every displaced geometry.
        disp: Size of the displacement applied to the coordinates.
        coords: Coordinates of the atoms, indexed as [atom][direction].
        specienames: Passed unchanged to writegen().
        species: Passed unchanged to writegen().
        origin: Passed unchanged to writegen().
        latvecs: Passed unchanged to writegen().
        charge_restart: Object whose 'name' attribute is the file the charge
            file is restored from before every run, or a false value if the
            charge file should not be touched.
        directions: Indices (0, 1, 2 for x, y, z) of the directions to displace.
        atoms: Zero-based indices of the atoms to displace.

    Returns:
        Array of shape (len(atoms), len(directions)) containing
        (E(-disp) - E(+disp)) / (2 * disp) for every atom and direction.

    Raises:
        ValueError: If no energy entry is found in 'autotest.tag'.
    """

    cart = ('x', 'y', 'z')
    delta = ('-', '+')
    energy = np.empty((2,), dtype=float)
    forces = np.empty((len(atoms), len(directions)), dtype=float)
    for aa, iat in enumerate(atoms):
        for ii, coord in enumerate(directions):
            for jj in range(2):
                # jj = 0 displaces by -disp, jj = 1 by +disp
                newcoords = np.array(coords)
                newcoords[iat][coord] += float(2 * jj - 1) * disp
                writegen("geo.gen", (specienames, species, newcoords, origin,
                                     latvecs))
                if charge_restart:
                    shutil.copy2(charge_restart.name, CHARGES_FILE)
                # NOTE(review): the return code is not checked, so a failed run
                # may leave the autotest.tag of the previous step in place.
                subprocess.call([binary])
                # The context manager closes the file also if reading fails.
                with open("autotest.tag", "r") as fp:
                    txt = fp.read()
                match = ENERGY_PATTERN.search(txt)
                print("iat: %2d, dir: %s, delta: %s"
                      % (iat + 1, cart[coord], delta[jj]))
                if not match:
                    raise ValueError("No energy match found in autotest.tag!")
                energy[jj] = float(match.group("value"))
                print("energy:", energy[jj])
            # Force is the negative derivative of the energy
            forces[aa][ii] = (energy[0] - energy[1]) / (2.0 * disp)
    return forces


def calculate_latderivs(binary, disp, coords, latvecs, specienames, species,
                        origin, charge_restart, directions):
    """Calculates lattice derivatives by central finite differences.

    For every selected direction and each of the three lattice vectors the
    vector component is displaced by -disp and +disp. The atomic coordinates
    are converted with cart2frac() using the original lattice vectors and
    back with frac2cart() using the displaced ones before the binary is run
    and the energy is read from 'autotest.tag'.

    Args:
        binary: Binary to execute for every displaced geometry.
        disp: Size of the displacement applied to the lattice vectors.
        coords: Coordinates of the atoms for the undisplaced lattice.
        latvecs: Lattice vectors, indexed as [vector][direction].
        specienames: Passed unchanged to writegen().
        species: Passed unchanged to writegen().
        origin: Passed unchanged to writegen().
        charge_restart: Object whose 'name' attribute is the file the charge
            file is restored from before every run, or a false value if the
            charge file should not be touched.
        directions: Indices (0, 1, 2 for x, y, z) of the directions to displace.

    Returns:
        Array of shape (3, len(directions)) containing
        (E(+disp) - E(-disp)) / (2 * disp) for every lattice vector and
        direction.

    Raises:
        ValueError: If no energy entry is found in 'autotest.tag'.
    """

    cart = ('x', 'y', 'z')
    delta = ('-', '+')
    energy = np.empty((2,), dtype=float)
    latderivs = np.empty((3, len(directions)), dtype=float)
    for ii, coord in enumerate(directions):
        for jj in range(3):
            for kk in range(2):
                # kk = 0 displaces by -disp, kk = 1 by +disp
                newcoords = np.array(coords)
                newcoords = cart2frac(latvecs, newcoords)
                newvecs = np.array(latvecs)
                newvecs[jj][coord] += float(2 * kk - 1) * disp
                newcoords = frac2cart(newvecs, newcoords)
                writegen("geo.gen", (specienames, species, newcoords, origin,
                                     newvecs))
                if charge_restart:
                    shutil.copy2(charge_restart.name, CHARGES_FILE)
                # NOTE(review): the return code is not checked, so a failed run
                # may leave the autotest.tag of the previous step in place.
                subprocess.call([binary])
                # The context manager closes the file also if reading fails.
                with open("autotest.tag", "r") as fp:
                    txt = fp.read()
                match = ENERGY_PATTERN.search(txt)
                print("dir: %s, ilatvec: %2d, delta: %s"
                      % (cart[coord], jj + 1, delta[kk]))
                if not match:
                    raise ValueError("No energy match found in autotest.tag!")
                energy[kk] = float(match.group("value"))
                # Print only after the assignment: printing earlier showed the
                # uninitialised or previous value instead of the current one.
                print("energy:", energy[kk])
            latderivs[jj][ii] = (energy[1] - energy[0]) / (2.0 * disp)
    return latderivs


def print_forces(forces, forces0, directions, atoms):
    """Prints the calculated forces and their deviation from a reference.

    Args:
        forces: Calculated forces, one row per entry of atoms and one column
            per entry of directions.
        forces0: Reference forces indexed as [atom, direction] for all atoms
            and directions, or None if no comparison should be printed.
        directions: Indices (0, 1, 2 for x, y, z) of the calculated directions.
        atoms: Zero-based indices of the atoms the rows of forces belong to.
    """

    axis_names = ('x', 'y', 'z')
    axis_label = "".join(axis_names[idir] for idir in directions)
    row_format = "%25.12E" * len(directions)

    def print_rows(rows):
        # One line per selected atom, labelled with its one-based index
        for irow, row in enumerate(rows):
            print("%3d:" % (atoms[irow] + 1), row_format % tuple(row))

    print("Forces by finite differences (" + axis_label + "):")
    print_rows(forces)
    if forces0 is None:
        return

    # Pick the reference components belonging to the calculated ones
    reference = np.empty((forces.shape))
    for irow, iat in enumerate(atoms):
        for icol, idir in enumerate(directions):
            reference[irow, icol] = forces0[iat, idir]
    print("Reference forces:")
    print_rows(reference)

    print("Difference between obtained and reference forces:")
    deviation = forces - reference
    print_rows(deviation)
    print("Max diff in any force component:")
    print("%25.12E" % (abs(deviation).max(),))


def print_latderivs(latderivs, latderivs0, directions):
    """Prints the lattice derivatives and their deviation from a reference.

    Args:
        latderivs: Calculated lattice derivatives with three rows and one
            column per entry of directions.
        latderivs0: Reference lattice derivatives with one column for every
            direction, or None if no comparison should be printed.
        directions: Indices (0, 1, 2 for x, y, z) of the calculated directions.
    """

    axis_names = ('x', 'y', 'z')
    axis_label = "".join(axis_names[idir] for idir in directions)
    row_format = "%25.12E" * len(directions)

    def print_matrix(matrix):
        # Exactly the first three rows are printed
        for irow in range(3):
            print(row_format % tuple(matrix[irow]))

    print("Lattice derivatives by finite differences (" + axis_label + "):")
    print_matrix(latderivs)
    if latderivs0 is None:
        return

    # Pick the reference columns belonging to the calculated directions
    reference = np.empty((latderivs.shape))
    for icol, idir in enumerate(directions):
        reference[:, icol] = latderivs0[:, idir]
    print("Reference lattice derivatives:")
    print_matrix(reference)

    print("Difference between obtained and reference lattice derivatives:")
    deviation = latderivs - reference
    print_matrix(deviation)
    print("Max diff in any lattice derivatives component:")
    print("%25.12E" % (abs(deviation).max(), ))


def directions_from_args(coords):
    """Determines the directions of the calculation.

    Args:
        coords: String consisting of the characters 'x', 'y' and 'z'.

    Returns:
        Sorted list of the direction indices (0 for x, 1 for y, 2 for z)
        occurring in coords, every index being contained only once.

    Raises:
        ValueError: If coords contains any other character.
    """
    dirstrs = {'x': 0, 'y': 1, 'z': 2}
    directions = set()
    for coord in coords:
        try:
            directions.add(dirstrs[coord])
        except KeyError:
            # The closing quote of this message used to be a typographic one
            raise ValueError(f"'{coord}' is no valid direction!") from None
    return sorted(directions)


def atoms_from_args(inp, maxlen):
    """Determines the atoms for the calculation.

    Args:
        inp: Comma separated selection of one-based atom numbers. An item is
            either a single number or a range 'start:stop' including both
            limits. If the stop part contains a '-', it is counted from the
            end ('1:-x' excludes the last x atoms). None selects all atoms.
        maxlen: Number of atoms in the geometry.

    Returns:
        Sorted list of the zero-based indices of the selected atoms, every
        atom being contained only once.

    Raises:
        ValueError: If an item can not be interpreted or if it selects an atom
            number outside of 1..maxlen.
    """
    if inp is None:
        return list(range(maxlen))

    # A set is used as an atom selected repeatedly must be calculated only once
    atoms = set()
    for item in inp.split(","):
        if ":" in item:
            start_str, stop_str = item.split(":")
            start = int(start_str)
            stop = int(stop_str)
            if "-" in stop_str:
                # Negative upper limits are relative to the number of atoms
                stop += maxlen
            numbers = range(start, stop + 1)
        else:
            numbers = (int(item),)
        for num in numbers:
            # Without this check atom 0 would silently select the last atom
            # and too large numbers would only fail later during indexing.
            if not 1 <= num <= maxlen:
                raise ValueError(
                    f"Atom {num} selected by '{item}' is outside of the valid"
                    f" range 1:{maxlen}!")
            atoms.add(num - 1)
    return sorted(atoms)

# Run the workflow only if the file is executed as a script (not on import).
if __name__ == "__main__":
    main()
