#!/usr/bin/env python3
#------------------------------------------------------------------------------#
#  DFTB+: general package for performing fast atomistic simulations            #
#  Copyright (C) 2006 - 2023  DFTB+ developers group                           #
#                                                                              #
#  See the LICENSE file for terms of usage and distribution.                   #
#------------------------------------------------------------------------------#
#
"""
Calculates the elastic stiffness tensor by finite differences of the stress tensor.
"""

import argparse
from derivtools import readgen, writegen, cart2frac, frac2cart, stress2latderivs, AA__BOHR, exists, create_temporary_copy
import numpy as np
import re
import subprocess
import shutil

# 1 Pa expressed in atomic units of pressure
PRESURE_AU = 0.339893208050290E-13

# Help text shown by the command line parser. This has to be a raw string:
# in a normal string literal the "\a" and "\b" of "\alpha\beta" are escape
# sequences (BEL and backspace) and would corrupt the printed help.
DESCRIPTION = r"""
Calculates Cijkl tensor in the Voight form, C_\alpha\beta, using the
specified DFTB+ binary and applying finite difference strain
displacements to the reference geometry. The geometry of the
configuration must be specified in a file called
'geo.gen.template'. The DFTB+ input file should include the geometry
from the file 'geo.gen'. The input file must specify the option for
writing an autotest file and the option for calculating the forces, as
this also outputs the stress tensor. The lattice vectors should be
fixed in the dftb_in.hsd file, but the atom positions (internal
coordinates) can either be clamped or relaxed depending on which is
required.
"""

# Matches a "stress" entry with shape "3,3" followed by three lines of three
# whitespace separated values. Group "valueXY" is the X-th value on the Y-th
# line of the entry.
STRESS_PATTERN = re.compile(
    r"stress\s*:[^:]*:[^:]*:3,3\n"
    r"\s*(?P<value11>\S+)\s+(?P<value21>\S+)\s+(?P<value31>\S+)\n"
    r"\s*(?P<value12>\S+)\s+(?P<value22>\S+)\s+(?P<value32>\S+)\n"
    r"\s*(?P<value13>\S+)\s+(?P<value23>\S+)\s+(?P<value33>\S+)\n"
)

# File that is saved before, and put back for, every strained calculation
CHARGES_FILE = "charges.bin"

def main():
    """Main routine

    Reads the reference geometry from 'geo.gen.template', runs the finite
    difference calculations with the binary given on the command line and
    prints the resulting elastic constants.

    Raises:
        ValueError: If the reference geometry contains no lattice vectors.
    """

    args = parse_arguments()
    specienames, species, coords, origin, latvecs = readgen("geo.gen.template")
    if latvecs is None:
        raise ValueError("Structure must be periodic for elastic constants!")

    # Keep a copy of an already present charge file, so that it can be put
    # back in place before every run and once all runs are done.
    if exists(CHARGES_FILE):
        chargeRestart = create_temporary_copy(CHARGES_FILE)
    else:
        chargeRestart = None

    disp = args.disp * AA__BOHR
    binary = args.binary
    print("BINARY:", binary)

    try:
        elastic = calculate_elastic(binary, disp, coords, latvecs,
                                    specienames, species, origin,
                                    chargeRestart)
    finally:
        # Put the saved charge file back even when a calculation failed,
        # otherwise the file left behind by the last strained run would stay.
        if chargeRestart:
            shutil.copy2(chargeRestart.name, CHARGES_FILE)

    print_elastic(elastic)

def parse_arguments():
    """Builds the command line parser and parses the command line.

    Returns:
        The argparse namespace with the attributes 'disp' (float, 1e-5 if the
        -d/--displacement option is not given) and 'binary' (the positional
        argument).
    """

    cli = argparse.ArgumentParser(description=DESCRIPTION)
    cli.add_argument(
        "-d",
        "--displacement",
        type=float,
        dest="disp",
        default=1e-5,
        help="Specify the displacement of the atoms (unit: ANGSTROM)",
    )
    cli.add_argument("binary", help="DFTB+ binary")
    return cli.parse_args()


def calculate_elastic(binary, disp, coords, latvecs, specienames, species,
                      origin, chargeRestart):
    """Calculates elastic constants by finite differences

    For each of the six Voigt components the reference cell is strained in
    the negative and in the positive direction, the strained geometry is
    written to 'geo.gen', the binary is run and the stress tensor is read
    back from 'autotest.tag'.

    Args:
        binary: Executable to run for every strained geometry (it is invoked
            without any arguments).
        disp: Size of the strain step. Diagonal components of the strain
            matrix are changed by -disp/+disp, a pair of off-diagonal
            components by -disp/2 / +disp/2 each.
        coords: Cartesian coordinates of the reference geometry.
        latvecs: Lattice vectors of the reference geometry.
        specienames: Passed on unchanged to writegen.
        species: Passed on unchanged to writegen.
        origin: Passed on unchanged to writegen.
        chargeRestart: Object whose 'name' attribute is the path of a saved
            charge file which is copied to CHARGES_FILE before every run, or
            a false value if there is nothing to copy.

    Returns:
        6x6 numpy array. Row ii holds, for the strain component ii, the six
        Voigt stress components at the negative step minus those at the
        positive step, divided by 2 * disp.
        NOTE(review): the sign presumably follows the stress sign convention
        of the program being run -- confirm.

    Raises:
        RuntimeError: If the binary finishes with a non-zero exit status.
        ValueError: If 'autotest.tag' contains no stress tensor.
    """

    # Stress in Voigt form for the negative (column 0) and positive
    # (column 1) strain step
    stress = np.empty((6, 2), dtype=float)
    # to hold resulting elastic tensor
    elastic = np.empty((6, 6), dtype=float)

    # Voigt convention for mapping one index onto two tensor indices
    voigt = [[0, 0], [1, 1], [2, 2], [1, 2], [0, 2], [0, 1]]

    for ii in range(6):
        row, col = voigt[ii]
        for kk in range(2):
            # kk = 0 is the negative, kk = 1 the positive step
            delta = 0.5 * float(2 * kk - 1) * disp
            strain = np.identity(3, dtype=float)
            # For diagonal components both lines hit the same element
            strain[row][col] += delta
            strain[col][row] += delta

            # Strain the cell while keeping the fractional coordinates
            newcoords = cart2frac(latvecs, np.array(coords))
            newvecs = np.dot(np.array(latvecs), strain)
            newcoords = frac2cart(newvecs, newcoords)
            writegen("geo.gen",
                     (specienames, species, newcoords, origin, newvecs))

            if chargeRestart:
                shutil.copy2(chargeRestart.name, CHARGES_FILE)

            # A failed run must not be ignored: an 'autotest.tag' left over
            # from an earlier run would otherwise be parsed instead.
            status = subprocess.call([binary])
            if status != 0:
                raise RuntimeError(
                    "Binary '%s' finished with exit status %d"
                    % (binary, status))

            stress[:, kk] = _read_voigt_stress("autotest.tag")

        # Central difference, evaluated once both steps are available
        elastic[ii, :] = (stress[:, 0] - stress[:, 1]) / (2.0 * disp)

    return elastic


def _read_voigt_stress(tagfile):
    """Reads the stress tensor from a tag file as six Voigt components."""

    with open(tagfile, "r") as fp:
        txt = fp.read()
    match = STRESS_PATTERN.search(txt)
    if not match:
        raise ValueError("Stress tensor missing from autotest.tag!")
    val = {name: float(text) for name, text in match.groupdict().items()}
    # Off-diagonal components are the average of the two symmetric elements
    return [
        val["value11"],
        val["value22"],
        val["value33"],
        0.5 * val["value23"] + 0.5 * val["value32"],
        0.5 * val["value13"] + 0.5 * val["value31"],
        0.5 * val["value12"] + 0.5 * val["value21"],
    ]


def print_elastic(elastic):
    """Prints calculated elastic constants.

    Args:
        elastic: 6x6 elastic tensor in atomic units of pressure. The
            argument is left unchanged, the conversion to GPa is done on a
            copy.
    """

    # Work on a copy: converting in place would change the caller's data and
    # make a second call convert the values twice.
    elastic_gpa = np.array(elastic, dtype=float) / (PRESURE_AU * 1.0E9)
    print("\nElastic constants (GPa):")
    for ii in range(6):
        print(("%12.4f" * 6) % tuple(elastic_gpa[ii]))


# Run the calculation only when the file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
