#!/usr/bin/env python3
#------------------------------------------------------------------------------#
#  DFTB+: general package for performing fast atomistic simulations            #
#  Copyright (C) 2006 - 2023  DFTB+ developers group                           #
#                                                                              #
#  See the LICENSE file for terms of usage and distribution.                   #
#------------------------------------------------------------------------------#
#
'''
Fourier transform of DFTB+ TD data
'''

import sys
import argparse
import numpy as np
from scipy import constants

DESCRIPTION = """Reads output from TD calculation (after kick) and produces absorption spectra.
Needs mux.dat, muy.dat and muz.dat files in working directory."""

# Length of the zero-padded FFT input, as a multiple of the number of time steps.
_PADDING_FACTOR = 10

# Divisor applied to the triplet signal (speed of light in atomic units, rounded).
_CSPEED_AU = 137.

# Energy window (in eV) of the spectra written to disk.
_EMIN_EV = 0.5
_EMAX_EV = 30.


def _load_dipole(filename):
    '''Loads a dipole time series as a 2D array, requiring at least two time steps.'''
    # ndmin=2 keeps a single-row file two-dimensional, so that the size check
    # below triggers instead of an obscure indexing error later on.
    data = np.loadtxt(filename, ndmin=2)
    if data.shape[0] < 2:
        raise ValueError('{} must contain at least two time steps'.format(filename))
    return data


def _ask_spin_type():
    '''Asks the user for the excitation type and returns it as 's' or 't'.'''
    answer = str(input('Please select singlet or triplet excitations (s/t)')).strip().lower()
    if answer in ('s', 'singlet'):
        return 's'
    if answer in ('t', 'triplet'):
        return 't'
    raise ValueError("unknown excitation type '{}', expected 's' or 't'".format(answer))


def _induced_dipole(data, column, spintype):
    '''Returns one dipole component relative to its value at the first time step.

    Without spin polarisation (spintype None) only the given column is used. Otherwise
    the column three places to the right is added (singlet, 's') or subtracted
    (triplet, 't'); presumably these are the two spin channels of the same component.
    '''
    first = data[:, column]
    if spintype is None:
        return first - first[0]
    second = data[:, column + 3]
    if spintype == 's':
        return first + second - first[0] - second[0]
    return first - second - first[0] + second[0]


def _averaged_response(mux, muy, muz, spintype):
    '''Averages the x, y and z responses, each taken along its own kick direction.'''
    average = (_induced_dipole(mux, 1, spintype) + _induced_dipole(muy, 2, spintype)
               + _induced_dipole(muz, 3, spintype)) / 3.0
    if spintype == 't':
        average = average / _CSPEED_AU
    return average


def _absorption_spectrum(time, signal, tau, field, hplanck):
    '''Fourier transforms the damped signal.

    Args:
        time: sampling times (fs, according to the command line help); the sampling
            interval is taken from the first two entries.
        signal: averaged dipole response at those times.
        tau: time constant of the exponential damping, same unit as time.
        field: field intensity the spectrum is normalised with.
        hplanck: Planck constant in eV fs.

    Returns:
        Energies in eV, wavelengths in nm (zero frequency excluded, hence one
        element shorter) and the absorption on the energy grid.
    '''
    nfft = _PADDING_FACTOR * time.shape[0]
    timestep = time[1] - time[0]
    spec = np.fft.rfft(np.exp(-time / tau) * signal, nfft)
    energsev = np.fft.rfftfreq(nfft, timestep) * hplanck
    frec = np.fft.rfftfreq(nfft, timestep * 1.0E-15)
    absorption = -2.0 * energsev * spec.imag / np.pi / field
    # The zero frequency has no finite wavelength, so it is skipped.
    energsnm = constants.nu2lambda(frec[1:]) * 1.0E9
    return energsev, energsnm, absorption


def main():
    '''Converts the dipole responses of a kicked TD run into absorption spectra.

    Reads mux.dat, muy.dat and muz.dat from the working directory (first column: time,
    following columns: dipole components) and writes the spectrum between 0.5 and 30 eV
    to spec-ev.dat (energy in eV) and spec-nm.dat (wavelength in nm).

    If the files have more than four columns, the run is treated as spin polarised
    and the user is asked whether singlet or triplet excitations should be calculated.

    The script terminates with an error message if the damping constant is not
    positive, the field intensity is zero, a data file is missing or inconsistent,
    or the excitation type is not recognised.
    '''
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument("-d", "--damping", dest="tau", required=True,
                        type=float, help="damping constant in fs")
    parser.add_argument("-f", "--field", dest="field", required=True,
                        type=float, help="field intensity in volts/angstrom")

    args = parser.parse_args()

    # Both quantities are used as divisors, zero would yield NaN/inf spectra.
    if args.tau <= 0.0:
        parser.error('damping constant must be positive')
    if args.field == 0.0:
        parser.error('field intensity must be non-zero')

    try:
        mux = _load_dipole('mux.dat') # response to excitation in x direction
        muy = _load_dipole('muy.dat') # response to excitation in y direction
        muz = _load_dipole('muz.dat') # response to excitation in z direction
        if not mux.shape[0] == muy.shape[0] == muz.shape[0]:
            raise ValueError('mux.dat, muy.dat and muz.dat differ in their number of time steps')

        spintype = None
        if mux.shape[1] > 4:
            print('This propagation was done using collinear spin polarisation')
            spintype = _ask_spin_type()
    except (OSError, ValueError) as err:
        sys.exit('Error: {}'.format(err))

    average = _averaged_response(mux, muy, muz, spintype)

    hplanck = constants.physical_constants['Planck constant in eV s'][0] * 1.0E15
    cspeednm = constants.speed_of_light * 1.0e9 / 1.0e15
    # The time grid of mux.dat is used for all three directions.
    energsev, energsnm, absorption = _absorption_spectrum(mux[:, 0], average, args.tau,
                                                          args.field, hplanck)

    # Wavelength window equivalent to the energy window.
    wvlmin = hplanck * cspeednm / _EMAX_EV
    wvlmax = hplanck * cspeednm / _EMIN_EV

    evwindow = (energsev > _EMIN_EV) & (energsev < _EMAX_EV)
    np.savetxt('spec-ev.dat', np.column_stack((energsev[evwindow], absorption[evwindow])))

    # The wavelength grid lacks the zero frequency, so the absorption is shifted accordingly.
    nmwindow = (energsnm > wvlmin) & (energsnm < wvlmax)
    np.savetxt('spec-nm.dat', np.column_stack((energsnm[nmwindow], absorption[1:][nmwindow])))


# Run the conversion only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
