#!/usr/bin/env python3

"""Simple script to sort module imports in Fortran files"""

import argparse
import re
import sys
from typing import Dict, List, Tuple

# Program description handed to argparse (shown in the --help output).
_DESCRIPTION = "Sorts the 'use' statements in DFTB+ Fortran file"

# Matches one complete 'use' statement, including any '&' continuation lines
# and the terminating newline. Named groups:
#   indent    - leading blanks of the statement
#   attrib    - optional ", intrinsic" attribute (empty string if absent)
#   separator - blanks and/or "::" standing between 'use' (or the attribute)
#               and the module name
#   name      - module name
#   rest      - everything after the name, up to but excluding the final
#               newline (continuation lines included)
# The pattern is compiled with re.VERBOSE, where literal blanks in the pattern
# are ignored; every blank that must actually be matched is therefore written
# as "[ ]". In particular the blanks after "::" have to be spelled "[ ]*",
# otherwise "use :: name" and "use, intrinsic :: name" would never match.
_PAT_USE_MODULE = re.compile(
    r"""^(?P<indent>[ ]*)use
    (?P<attrib>(?:\s*,\s*intrinsic)?)
    (?P<separator>[ ]*::[ ]*|[ ]+)
    (?P<name>\w+)
    (?P<rest>.*?(?:&[ ]*\n(?:[ ]*&)?.*?)*)\n
    """, re.VERBOSE | re.MULTILINE)


def main():
    """Main script driver.

    Processes every file named on the command line: reads it, sorts its
    'use' statements, and writes the result back to the same file. A notice
    is printed for each file that contains more than one block of 'use'
    statements.
    """

    args = _parse_arguments()
    for fname in args.filenames:
        # Context managers make sure the handles are closed (and the written
        # data flushed) even if processing stops early.
        with open(fname, "r") as infile:
            txt = infile.read()
        blocks, output = _process_file_content(txt, fname)
        with open(fname, "w") as outfile:
            outfile.write("\n".join(output) + "\n")
        if blocks > 1:
            print(f"{fname}: multiple blocks found!")


def _parse_arguments() -> argparse.Namespace:
    """Builds the command line parser and returns the parsed arguments.

    The only argument is 'filenames': one or more files to process.
    """

    parser = argparse.ArgumentParser(description=_DESCRIPTION)
    parser.add_argument(
        "filenames",
        metavar="FILE",
        nargs="+",
        help="File to process",
    )
    return parser.parse_args()


def _process_file_content(txt: str, fname: str) -> Tuple[int, List[str]]:
    """Sorts the 'use' statements found in the text of one file.

    Consecutive 'use' statements (each one starting exactly where the
    previous one ended) form a block; every block is sorted on its own. The
    text between blocks is kept, with its trailing whitespace stripped.

    Args:
        txt: Full content of the file.
        fname: Name of the file, used in error messages only.

    Returns:
        Number of 'use' blocks found and the list of output chunks, which the
        caller joins with newlines.

    A module occurring twice within one block is reported via _fatal_error().
    """

    chunks = []
    pending = {}
    nblocks = 0
    cursor = 0
    for found in _PAT_USE_MODULE.finditer(txt):
        modname = found.group('name').lower()
        begin = found.start()
        if begin != cursor:
            # Something else precedes this statement: the block collected so
            # far (if any) is complete, emit it followed by the text between.
            if pending:
                chunks.extend(_get_sorted_modules(pending))
                nblocks += 1
                pending = {}
            chunks.append(txt[cursor:begin].rstrip())
        if modname in pending:
            _fatal_error(f"{fname}: multiple occurrences of module '{modname}'!")
        pending[modname] = found
        cursor = found.end()

    # Emit the last block and whatever follows the final 'use' statement.
    if pending:
        chunks.extend(_get_sorted_modules(pending))
        nblocks += 1
    chunks.append(txt[cursor:].rstrip())
    return nblocks, chunks


def _get_sorted_modules(modules: Dict[str, re.Match]) -> List[str]:
    """Renders the 'use' statements of one block in sorted order.

    The modules are arranged in three groups, each sorted alphabetically by
    name: first those named 'mpi' or starting with 'iso_', then all others,
    and finally those starting with 'dftbplus_'.

    Args:
        modules: Match objects of the 'use' statements, keyed by module name.

    Returns:
        One rebuilt 'use' statement per module, with the module name in lower
        case and all other matched parts reproduced as found.
    """

    def group_rank(modname):
        if modname.startswith('dftbplus_'):
            return 2
        if modname.startswith('iso_') or modname == 'mpi':
            return 0
        return 1

    ordered = sorted(modules, key=lambda modname: (group_rank(modname), modname))
    statements = []
    for modname in ordered:
        parts = modules[modname]
        statements.append(
            f"{parts['indent']}use{parts['attrib']}"
            f"{parts['separator']}{parts['name'].lower()}"
            f"{parts['rest']}"
        )
    return statements


def _fatal_error(msg: str):
    """Reports an error on standard error and exits with status 1."""

    print(f"Error: {msg}", file=sys.stderr)
    raise SystemExit(1)


# Run the driver only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
