#!/usr/bin/env python3
#------------------------------------------------------------------------------#
#  DFTB+: general package for performing fast atomistic simulations            #
#  Copyright (C) 2006 - 2023  DFTB+ developers group                           #
#                                                                              #
#  See the LICENSE file for terms of usage and distribution.                   #
#------------------------------------------------------------------------------#

############################################################################
#
# get_neginvjac -- Computes the negative inverse Jacobian from finite
#                  differences of DFTB+ output charges and writes it to a file
#
############################################################################
import numpy as np
import numpy.linalg as la
import subprocess
import re
import shutil
import argparse


# File names used by the script (all relative to the current directory).
# The template is copied to DFTB_IN and the generated settings are appended
# to that copy before every run.
DFTB_IN_TEMPLATE = "dftb_in.hsd.template"
DFTB_IN = "dftb_in.hsd"
# File the output charges are parsed from after each run.
RESULTS_TAG = "results.tag"
# File the final matrix is written to.
INVERSE_JACOBIAN_FNAME = "neginvjac.dat"

# Opening of the HSD fragment appended to the input file.
# NOTE(review): the "+" and "!" tag prefixes used in these fragments are
# presumably HSD modifiers that extend / override what the template already
# defines -- confirm against the DFTB+ manual.
DFTB_BEGIN = """
+Hamiltonian {
  +DFTB {
"""

# Total charge and per-atom initial charges. This is a str.format template:
# {totalcharge} and {allatomcharges} are placeholders, doubled braces are
# literal braces in the generated text.
DFTB_CHARGES = """
    Charge = {totalcharge:e}
    !InitialCharges {{
      AllAtomCharges {{
        {allatomcharges:s}
      }}
    }}
"""

# SCC tolerance setting (str.format template with placeholder {scctol}).
DFTB_SCCTOL = """
    !SCCTolerance = {scctol:e}
"""

# Maximal number of SCC iterations (str.format template, integer placeholder
# {maxscciter}).
DFTB_MAXSCCITER = """
    !MaxSCCIterations = {maxscciter:d}
"""

# Closes the two blocks opened by DFTB_BEGIN.
DFTB_END = """
  }
}
"""

# Finite-difference stencils, keyed by the number of sample points.
# INPUT_PREFACS: multiples of the step by which one input charge is displaced.
INPUT_PREFACS = { 2: [ -1.0, 1.0 ] }

# OUTPUT_PREFACS: weights of the corresponding output charges. After division
# by the step the 2-point entry yields (q(+d) - q(-d)) / (2 d).
OUTPUT_PREFACS = { 2: [ -0.5, 0.5 ] }


# Matches the "net_atomic_charges" entry of the results file: the tag name,
# three further colon-separated header fields (the last one starting with an
# integer) and then a whitespace-separated run of numbers, captured in the
# named group "values". Numbers may carry a sign, a fractional part and an
# upper-case "E" exponent.
# Both fragments are raw strings so that "\s", "\d" and "\." reach the regex
# engine verbatim instead of being treated as (invalid) string escapes.
ATOMIC_CHARGES_PATTERN = re.compile(
    r"net_atomic_charges[^:]*:[^:]*:[^:]*:\d+\s*"
    r"(?P<values>(?:\s*[+-]?\d+(?:\.\d+(?:E[+-]?\d+)?)?)+)",
    re.MULTILINE)


def parse_command_line():
    """Parse the command-line arguments of the script.

    Returns:
        argparse.Namespace with the attributes ``diff`` (float, set by the
        optional ``-d`` flag, default 1e-4), ``binary`` (first positional
        argument) and ``natom`` (second positional argument, int).
    """
    argument_specs = (
        (("-d",), {"dest": "diff", "type": float, "default": 1e-4,
                   "help": "Finite difference"}),
        (("binary",), {"help": "DFTB+ binary"}),
        (("natom",), {"type": int, "help": "Number of atoms"}),
    )
    cli = argparse.ArgumentParser()
    for flags, options in argument_specs:
        cli.add_argument(*flags, **options)
    namespace = cli.parse_args()
    return namespace


def create_input(charges=None, scctol=None, maxscciter=None):
    """Create the DFTB+ input file from the template plus optional settings.

    The template file is copied to the input file name and a Hamiltonian/DFTB
    fragment is appended to the copy. Only the settings whose argument is not
    ``None`` are written into that fragment.

    Args:
        charges: Iterable of per-atom charges written as initial charges, or
            None to leave the initial charges untouched.
        scctol: SCC tolerance to write, or None.
        maxscciter: Maximal number of SCC iterations (int) to write, or None.
    """
    shutil.copy(DFTB_IN_TEMPLATE, DFTB_IN)

    # Assemble the complete fragment first, so a formatting error cannot leave
    # a half-written block (or an open file handle) behind.
    fragments = [DFTB_BEGIN]
    if charges is not None:
        # NOTE(review): the total charge is fixed to 0.0 instead of using
        # np.sum(charges), although the perturbed charges do not sum to the
        # unperturbed total -- presumably intentional, confirm.
        totalcharge = 0.0
        charges_str = "\n".join(
            "{:23.15E}".format(charge) for charge in charges)
        fragments.append(DFTB_CHARGES.format(totalcharge=totalcharge,
                                             allatomcharges=charges_str))
    if scctol is not None:
        fragments.append(DFTB_SCCTOL.format(scctol=scctol))
    if maxscciter is not None:
        fragments.append(DFTB_MAXSCCITER.format(maxscciter=maxscciter))
    fragments.append(DFTB_END)

    # Append-only access is enough; the context manager closes the file even
    # if the write fails.
    with open(DFTB_IN, "a") as fp:
        fp.write("".join(fragments))


def get_output_charges():
    """Read the net atomic charges from the results file.

    Returns:
        One-dimensional float numpy array with the values of the first
        ``net_atomic_charges`` entry found in the file.

    Raises:
        RuntimeError: If the file contains no matching entry.
    """
    with open(RESULTS_TAG, "r") as fp:
        txt = fp.read()
    match = ATOMIC_CHARGES_PATTERN.search(txt)
    if match is None:
        # Without this check the failure would surface as an opaque
        # AttributeError on None.
        raise RuntimeError(
            "No net atomic charges found in '{}'".format(RESULTS_TAG))
    return np.array(match.group("values").split(), dtype=float)


def do_calculation(binary):
    """Run the given binary without arguments in the current directory.

    Args:
        binary: Name or path of the executable to start.

    Raises:
        subprocess.CalledProcessError: If the process exits with a non-zero
            status. Continuing after a failed run would make the caller parse
            results left over from an earlier calculation.
    """
    subprocess.check_call([binary])


def write_inverse_jacobian_kernel(invjac):
    """Write a matrix to the inverse Jacobian output file and report it.

    Every row of the matrix becomes one line; each element is written in a
    25 character wide exponential format with 13 decimals.

    Args:
        invjac: Iterable of rows, each an iterable of numbers.
    """
    # The context manager guarantees the file is closed even if formatting
    # of an element fails.
    with open(INVERSE_JACOBIAN_FNAME, "w") as fp:
        for row in invjac:
            fp.write("".join("{:25.13E}".format(elem) for elem in row))
            fp.write("\n")
    print("*** Negative inverse Jacobian written to '{}'".format(
        INVERSE_JACOBIAN_FNAME))


def main():
    """Compute the negative inverse Jacobian by finite differences.

    Workflow:
      1. Run a reference calculation with the unmodified template and read
         its output charges.
      2. For every atom, displace its input charge by +/- ``diff``, run a
         calculation limited to one SCC iteration for each displacement and
         combine the output charges into a central-difference derivative.
         The derivative with respect to atom ``i`` fills column ``i`` of the
         derivative matrix.
      3. Invert (derivative matrix - identity), negate it, normalise every
         column by its column sum and write the result to file.

    Raises:
        ValueError: If the number of charges read from the reference run
            differs from the atom count given on the command line.
    """
    args = parse_command_line()
    binary = args.binary
    natom = args.natom
    diff = args.diff
    order = 2
    input_prefacs = INPUT_PREFACS[order]
    output_prefacs = OUTPUT_PREFACS[order]

    # Reference calculation providing the unperturbed charges.
    create_input()
    do_calculation(binary)
    charges0 = get_output_charges()

    # Fail early: a mismatch would otherwise only show up as a broadcasting
    # or index error after the first perturbed calculations.
    if charges0.shape != (natom,):
        raise ValueError(
            "Expected {:d} atomic charges, but '{}' contains {:d}".format(
                natom, RESULTS_TAG, charges0.size))

    derivmtx = np.empty((natom, natom), dtype=float)
    for iatom in range(natom):
        print("*** {:d} / {:d}".format(iatom + 1, natom))
        derivs = np.zeros((natom,), dtype=float)
        for input_prefac, output_prefac in zip(input_prefacs, output_prefacs):
            charges = charges0.copy()
            charges[iatom] += input_prefac * diff
            # Only a single SCC iteration is allowed; the huge tolerance
            # presumably makes that single iteration count as converged.
            create_input(charges, scctol=1e10, maxscciter=1)
            do_calculation(binary)
            derivs += output_prefac * get_output_charges()
        derivs /= diff
        derivmtx[:, iatom] = derivs

    # Diagnostic: largest absolute column sum of the derivative matrix.
    maxdev = np.max(np.abs(np.sum(derivmtx, axis=0)))
    print("\nMax column sum deviation in deriv. matrix:\t", maxdev)
    invjac = -1.0 * la.inv(derivmtx - np.eye(natom, dtype=float))
    columnsums = np.sum(invjac, axis=0)
    maxdev = np.max(np.abs(columnsums - 1.0))
    print("Max column sum deviation in inv. Jacobian:\t", maxdev)
    # Broadcasting divides every column by its own sum, so each column of the
    # written matrix sums to one.
    invjac /= columnsums
    write_inverse_jacobian_kernel(invjac)


# Start the workflow only when the file is executed as a script.
if __name__ == "__main__":
    main()
