#!/usr/bin/env bash
#
# This shell script (nvcc_wrapper) wraps both the host compiler and
# NVCC, if you are building legacy C or C++ code with CUDA enabled.
# The script remedies some differences between the interface of NVCC
# and that of the host compiler, in particular for linking.
# It also means that a legacy code doesn't need separate .cu files;
# it can just use .cpp files.
#
# Default settings: change those according to your machine.  For
# example, you may have two different wrappers with either icpc
# or g++ as their back-end compiler.  The defaults can be overwritten
# by using the usual arguments (e.g., -arch=sm_80 -ccbin icpc).
# sm_70 is supported by every CUDA version from 9-12 and is thus
# chosen as default

# Architecture passed to nvcc as -arch=<value> when the caller supplies
# neither -arch nor -gencode.
default_arch="sm_70"
#default_arch="sm_80"

#
# The default C++ compiler.  NVCC_WRAPPER_DEFAULT_COMPILER replaces the
# built-in g++ default; a -ccbin argument on the command line replaces both.
#
host_compiler=${NVCC_WRAPPER_DEFAULT_COMPILER:-"g++"}

# Default to whatever nvcc is in the path, unless CUDA_ROOT names a specific
# installation.  CUDA_ROOT is quoted so that the test stays well-formed for
# any value (whitespace or glob characters would otherwise make `[` fail and
# silently select the nvcc found in the path instead).
nvcc_compiler=nvcc
if [ -n "$CUDA_ROOT" ]; then
  nvcc_compiler="$CUDA_ROOT/bin/nvcc"
fi

#host_compiler="icpc"
#host_compiler="/usr/local/gcc/4.8.3/bin/g++"
#host_compiler="/usr/local/gcc/4.9.1/bin/g++"

#
# Internal variables
#

# All state filled in by the argument parser is initialised here, so that
# nothing is inherited by accident from the caller's environment.

# C++ files (compiled by nvcc as CUDA sources via "-x cu")
cpp_files=""

# Host compiler arguments (comma separated, handed to nvcc via -Xcompiler)
xcompiler_args=""

# Cuda (NVCC) only arguments
cuda_args=""

# Arguments for both NVCC and Host compiler
shared_args=""

# Argument -c
compile_arg=""

# Argument -o <obj>
output_arg=""

# Linker arguments in nvcc form (each one prefixed with -Xlinker)
xlinker_args=""

# Linker arguments for the host-only command line.  The parser only ever
# appends to this variable, so it has to start out empty.
host_linker_args=""

# Object files passable to NVCC
object_files=""

# Link objects for the host linker only
object_files_xlinker=""

# Shared libraries with version numbers are not handled correctly by NVCC
shared_versioned_libraries_host=""
shared_versioned_libraries=""

# Does the User set the architecture
arch_set=0
arch_flag=""

# Does the user set RDC?
rdc_set=0
rdc_flag=""

# Does the user overwrite the host compiler
ccbin_set=0

#Error code of compilation
error_code=0

# Do a dry run without actually compiling
dry_run=0

# Skip NVCC compilation and use host compiler directly
host_only=0
host_only_args=""

# Just run version on host compiler
get_host_version=0

# Enable workaround for CUDA 6.5 for pragma ident
replace_pragma_ident=0

# Mark first host compiler argument (1 while xcompiler_args is still empty)
first_xcompiler_arg=1

# Allow for setting temp dir without setting TMPDIR in parent (see https://docs.olcf.ornl.gov/systems/summit_user_guide.html#setting-tmpdir-causes-jsm-jsrun-errors-job-state-flip-flop)
# NVCC_WRAPPER_TMPDIR wins whenever it is set (even to an empty string);
# otherwise TMPDIR is used, and /tmp if that is unset or empty.
temp_dir=${NVCC_WRAPPER_TMPDIR-${TMPDIR:-/tmp}}

# optimization flag added as a command-line argument
optimization_flag=""

# std standard flag added as a command-line argument
std_flag=""

# Run nvcc a second time to generate dependencies if needed
depfile_separate=0
depfile_output_arg=""
depfile_target_arg=""

# Option to remove duplicate libraries and object files
remove_duplicate_link_files=0

# Print the notice that more than one C++ standard flag was given and that
# only the last one is kept.  Takes no arguments; writes one line to stdout.
warn_std_flag() {
  printf '%s\n' "nvcc_wrapper - *warning* you have set multiple standard flags (-std=c++1* or --std=c++1*), only the last is used because nvcc can only accept a single std setting"
}

#echo "Arguments: $# $@"

# Append one argument to the comma-separated list that is handed to nvcc via
# -Xcompiler.  first_xcompiler_arg is 1 while the list is still empty, so the
# first entry is stored without a leading comma.
#   $1 - argument to append
add_xcompiler_arg() {
  if [ "$first_xcompiler_arg" -eq 1 ]; then
    xcompiler_args="$1"
    first_xcompiler_arg=0
  else
    xcompiler_args="$xcompiler_args,$1"
  fi
}

# If a standard flag was already recorded, warn about it and remove it from
# shared_args: nvcc accepts a single std setting, so the last one given wins.
discard_previous_std_flag() {
  if [ -n "$std_flag" ]; then
    warn_std_flag
    shared_args=${shared_args/ $std_flag/}
  fi
}

# Record $1 as the active standard flag and pass it to both compilers.
record_std_flag() {
  std_flag=$1
  shared_args="$shared_args $std_flag"
}

# Record the requested standard flag if the installed nvcc is new enough,
# otherwise announce and record the -std=c++14 fallback.
#   $1 - standard flag as given on the command line
#   $2 - first CUDA major version whose nvcc supports that standard
# The major version is parsed from the "V<major>" token of `nvcc --version`.
# If it cannot be determined the requested flag is kept unchanged.
record_std_flag_for_cuda() {
  local requested_flag=$1 minimum_cuda_major=$2
  local cuda_main_version fallback_std_flag
  cuda_main_version=$([[ $("$nvcc_compiler" --version) =~ V([0-9]+) ]] && echo "${BASH_REMATCH[1]}")
  if [ -n "$cuda_main_version" ] && [ "$cuda_main_version" -lt "$minimum_cuda_major" ]; then
    fallback_std_flag="-std=c++14"
    # this is hopefully just occurring in a downstream project during CMake feature tests
    # we really have no choice here but to accept the flag and change  to an accepted C++ standard
    echo "nvcc_wrapper does not accept standard flags $requested_flag since partial standard flags and standards after C++14 are not supported. nvcc_wrapper will use $fallback_std_flag instead. It is undefined behavior to use this flag. This should only be occurring during CMake configuration."
    record_std_flag "$fallback_std_flag"
  else
    record_std_flag "$requested_flag"
  fi
}

# Classify every command-line argument and store it in the matching global:
# cpp_files, shared_args, cuda_args, xcompiler_args, xlinker_args,
# host_linker_args, object_files, object_files_xlinker, output_arg,
# compile_arg, the depfile_* variables, host_only_args, the
# shared_versioned_libraries* variables and the arch/rdc/ccbin/std/
# optimization bookkeeping variables.
#   "$@" - the wrapper's command line
# Exits with status 1 if the architecture or the RDC setting is given twice
# with different values.
parse_arguments() {
  local unescape_commas arg kinds_str kind
  local -a kinds
  local first_werror_cuda first_werror_host xcompiler_args_werror
  local fallback_std_flag corrected_std_flag

  while [ $# -gt 0 ]
  do
    case $1 in
    #show the executed command
    --show|--nvcc-wrapper-show)
      dry_run=1
      ;;
    #run host compilation only
    --host-only)
      host_only=1
      ;;
    #get the host version only
    --host-version)
      get_host_version=1
      ;;
    #replace '#pragma ident' with '#ident' this is needed to compile OpenMPI due to a configure script bug and a non standardized behaviour of pragma with macros
    --replace-pragma-ident)
      replace_pragma_ident=1
      ;;
    #remove duplicate link files
    --remove-duplicate-link-files)
      remove_duplicate_link_files=1
      ;;
    #handle source files to be compiled as cuda files
    *.cpp|*.cxx|*.cc|*.C|*.c++|*.cu)
      cpp_files="$cpp_files $1"
      ;;
    # Ensure we only have one optimization flag because NVCC doesn't allow multiple
    -O*)
      if [ -n "$optimization_flag" ]; then
        if [ "$1" = "$optimization_flag" ]; then
          # Silently consume duplicates of the same argument
          shift
          continue
        fi
        echo "nvcc_wrapper - *warning* you have set multiple optimization flags (-O*), only the last is used because nvcc can only accept a single optimization setting."
        shared_args=${shared_args/ $optimization_flag/}
      fi
      # nvcc needs an explicit level, so a bare -O is mapped to -O2
      if [ "$1" = "-O" ]; then
        optimization_flag="-O2"
      else
        optimization_flag=$1
      fi
      shared_args="$shared_args $optimization_flag"
      ;;
    #Handle shared args (valid for both nvcc and the host compiler)
    -D*)
      # Turn "\," back into "," first so that the shell-quoting below does
      # not escape an already escaped comma a second time.
      unescape_commas=${1//\\,/,}
      # NOTE(review): the expansion is left unquoted as in the original code;
      # definitions containing whitespace are therefore not preserved.
      arg=$(printf "%q" $unescape_commas)
      shared_args="$shared_args $arg"
      ;;
    -I*|-L*|-l*|-g|--help|--version|-E|-M|-shared|-w)
      shared_args="$shared_args $1"
      ;;
    #Handle compilation argument
    -c)
      compile_arg="$1"
      ;;
    #Handle output argument
    -o)
      output_arg="$output_arg $1 $2"
      shift
      ;;
    # Handle depfile arguments.  We map them to a separate call to nvcc.
    -MD|-MMD)
      depfile_separate=1
      host_only_args="$host_only_args $1"
      ;;
    -MF)
      depfile_output_arg="-o $2"
      host_only_args="$host_only_args $1 $2"
      shift
      ;;
    -MT)
      depfile_target_arg="$1 $2"
      host_only_args="$host_only_args $1 $2"
      shift
      ;;
    # Handle nvcc args controlling whether to generated relocatable device code
    --relocatable-device-code=*|-rdc=*)
      if [ "$rdc_set" -eq 0 ]; then
        rdc_set=1
        rdc_flag="$1"
        cuda_args="$cuda_args $rdc_flag"
      elif [ "$rdc_flag" != "$1" ]; then
        echo "RDC is being set twice with different flags, which is not handled"
        echo "$rdc_flag"
        echo "$1"
        exit 1
      fi
      ;;
    -rdc)
      if [ "$rdc_set" -eq 0 ]; then
        rdc_set=1
        rdc_flag="$1 $2"
        cuda_args="$cuda_args $rdc_flag"
      elif [ "$rdc_flag" != "$1 $2" ]; then
        echo "RDC is being set twice with different flags, which is not handled"
        echo "$rdc_flag"
        echo "$1 $2"
        exit 1
      fi
      # The value is consumed even for an identical repetition; otherwise it
      # would be picked up as a stray host compiler argument.
      shift
      ;;
    #Handle known nvcc args
    --dryrun|--verbose|--keep|--source-in-ptx|-src-in-ptx|--keep-dir*|-G|-lineinfo|-extended-lambda|-expt-extended-lambda|-expt-relaxed-constexpr|--resource-usage|--fmad=*|--use_fast_math|--Wext-lambda-captures-this|-Wext-lambda-captures-this)
      cuda_args="$cuda_args $1"
      ;;
    #Handle more known nvcc args
    --extended-lambda|--expt-extended-lambda|--expt-relaxed-constexpr|--Wno-deprecated-gpu-targets|-Wno-deprecated-gpu-targets|-allow-unsupported-compiler|--allow-unsupported-compiler)
      cuda_args="$cuda_args $1"
      ;;
    #Handle known nvcc args that have an argument
    -maxrregcount=*|--maxrregcount=*|-time=*|-Xptxas=*)
      cuda_args="$cuda_args $1"
      ;;
    -maxrregcount|--default-stream|-Xnvlink|--fmad|-cudart|--cudart|-include|-time|-Xptxas)
      cuda_args="$cuda_args $1 $2"
      shift
      ;;
    # Handle Werror. Note, we must differentiate between the ones going to nvcc and the host compiler
    # --Werror kind,... OR --Werror=kind,... <- always to nvcc
    --Werror)
      cuda_args="$cuda_args $1 $2"
      shift
      ;;
    --Werror=*)
      cuda_args="$cuda_args $1"
      ;;
    # -Werror kind,... where kind is one of {all-warnings, cross-execution-space-call, reorder, default-stream-launch, missing-launch-bounds, ext-lambda-captures-this, deprecated-declarations} <- goes to nvcc
    # -Werror not followed by any kind as mentioned above goes to host compiler without any arguments
    -Werror)
      # Only the first kind of the (possibly missing) next argument is
      # inspected: if it is an nvcc kind the whole list belongs to nvcc.
      # A -Werror that is the last argument has no list and goes to the host.
      case ${2%%,*} in
      all-warnings|cross-execution-space-call|reorder|default-stream-launch|missing-launch-bounds|ext-lambda-captures-this|deprecated-declarations)
        cuda_args="$cuda_args $1 $2"
        shift
        ;;
      *)
        add_xcompiler_arg "$1"
        ;;
      esac
      ;;
    # -Werror=kind,... will be split into two parts, those kinds that belong to nvcc (see above) go there, while all others go towards the host compiler
    -Werror=*)
      kinds_str="${1:8}" # strip -Werror=
      IFS="," read -r -a kinds <<< "${kinds_str}"
      first_werror_cuda=1
      first_werror_host=1
      xcompiler_args_werror=
      # loop over all kinds that are separated via ','
      for kind in "${kinds[@]}"
      do
        case ${kind} in
        all-warnings|cross-execution-space-call|reorder|default-stream-launch|missing-launch-bounds|ext-lambda-captures-this|deprecated-declarations)
          # nvcc kinds are collected into a single -Werror=a,b,... argument
          if [ $first_werror_cuda -ne 0 ]; then
            cuda_args="$cuda_args -Werror="
            first_werror_cuda=0
          else
            cuda_args="$cuda_args,"
          fi
          cuda_args="$cuda_args$kind"
          ;;
        *)
          # host kinds become separate -Werror=<kind> entries
          if [ $first_werror_host -eq 0 ]; then
            xcompiler_args_werror="${xcompiler_args_werror},"
          fi
          first_werror_host=0
          xcompiler_args_werror="$xcompiler_args_werror-Werror=$kind"
          ;;
        esac
      done
      if [ $first_werror_host -eq 0 ]; then
        add_xcompiler_arg "$xcompiler_args_werror"
      fi
      ;;
    # End of Werror handling
    #Handle unsupported standard flags
    --std=c++1y|-std=c++1y|--std=gnu++1y|-std=gnu++1y|--std=c++1z|-std=c++1z|--std=gnu++1z|-std=gnu++1z|--std=c++2a|-std=c++2a)
      fallback_std_flag="-std=c++14"
      # this is hopefully just occurring in a downstream project during CMake feature tests
      # we really have no choice here but to accept the flag and change  to an accepted C++ standard
      echo "nvcc_wrapper does not accept standard flags $1 since partial standard flags and standards after C++17 are not supported. nvcc_wrapper will use $fallback_std_flag instead. It is undefined behavior to use this flag. This should only be occurring during CMake configuration."
      discard_previous_std_flag
      record_std_flag "$fallback_std_flag"
      ;;
    -std=gnu*)
      corrected_std_flag=${1/gnu/c}
      echo "nvcc_wrapper has been given GNU extension standard flag $1 - reverting flag to $corrected_std_flag"
      discard_previous_std_flag
      record_std_flag "$corrected_std_flag"
      ;;
    --std=c++20|-std=c++20)
      discard_previous_std_flag
      # NVCC only has C++20 from version 12 on
      record_std_flag_for_cuda "$1" 12
      ;;
    --std=c++17|-std=c++17)
      discard_previous_std_flag
      # NVCC only has C++17 from version 11 on
      record_std_flag_for_cuda "$1" 11
      ;;
    --std=c++11|-std=c++11|--std=c++14|-std=c++14)
      discard_previous_std_flag
      record_std_flag "$1"
      ;;

    #convert PGI standard flags to something nvcc can handle
    --c++11|--c++14|--c++17)
      discard_previous_std_flag
      record_std_flag "-std=${1#--}"
      ;;

    #ignore PGI forcing ISO C++-conforming code
    -A)
      ;;

    #strip off -std=c++98 due to nvcc warnings and Tribits will place both -std=c++11 and -std=c++98
    -std=c++98|--std=c++98)
      ;;
    #strip off pedantic because it produces endless warnings about #LINE added by the preprocessor
    -pedantic|-pedantic-errors|-Wpedantic|-ansi)
      ;;
    #strip off -Woverloaded-virtual to avoid "cc1: warning: command line option ‘-Woverloaded-virtual’ is valid for C++/ObjC++ but not for C"
    -Woverloaded-virtual)
      ;;
    #strip -Xcompiler because we add it
    -Xcompiler|--compiler-options)
      if [[ $2 != "-o" ]]; then
        add_xcompiler_arg "$2"
        shift
      fi
      # else we have -Xcompiler -o <filename>, in this case just drop -Xcompiler and process
      # the -o flag with the filename (done above)
      ;;
    #strip off "-x cu" because we add that
    -x)
      if [[ $2 != "cu" ]]; then
        add_xcompiler_arg "-x,$2"
      fi
      shift
      ;;
    #Handle -+ (same as -x c++, specifically used for xl compilers, but mutually exclusive with -x. So replace it with -x c++)
    -+)
      add_xcompiler_arg "-x,c++"
      ;;
    #Handle -ccbin (if its not set we can set it to a default value)
    -ccbin)
      cuda_args="$cuda_args $1 $2"
      ccbin_set=1
      host_compiler=$2
      shift
      ;;

    #Handle -arch argument (if its not set use a default) this is the version with = sign
    -arch=*|-gencode=*)
      if [ "$arch_set" -eq 0 ]; then
        arch_set=1
        arch_flag="$1"
        cuda_args="$cuda_args $arch_flag"
      elif [ "$arch_flag" != "$1" ]; then
        echo "ARCH is being set twice with different flags, which is not handled"
        echo "$arch_flag"
        echo "$1"
        exit 1
      fi
      ;;
    #Handle -arch argument (if its not set use a default) this is the version without = sign
    -arch|-gencode)
      if [ "$arch_set" -eq 0 ]; then
        arch_set=1
        arch_flag="$1 $2"
        cuda_args="$cuda_args $arch_flag"
      elif [ "$arch_flag" != "$1 $2" ]; then
        echo "ARCH is being set twice with different flags, which is not handled"
        echo "$arch_flag"
        echo "$1 $2"
        exit 1
      fi
      # The value is consumed even for an identical repetition; otherwise it
      # would be picked up as a stray host compiler argument.
      shift
      ;;
    #Handle -code argument, this is the version without = sign.  It has to be
    #matched before the "=" form so that its value is consumed as well.
    -code)
      cuda_args="$cuda_args $1 $2"
      shift
      ;;
    #Handle -code argument, this is the version with = sign
    -code=*)
      cuda_args="$cuda_args $1"
      ;;
    #Handle -Xcudafe argument
    -Xcudafe)
      cuda_args="$cuda_args -Xcudafe $2"
      shift
      ;;
    #Handle -Xlinker argument
    -Xlinker)
      xlinker_args="$xlinker_args -Xlinker $2"
      shift
      ;;
    #Handle args that should be sent to the linker
    -Wl,*)
      # nvcc gets the options behind "-Wl," via -Xlinker; the host compiler
      # needs the "-Wl," prefix itself to forward them to its linker.
      xlinker_args="$xlinker_args -Xlinker ${1:4}"
      host_linker_args="$host_linker_args $1"
      ;;
    #Handle object files: -x cu applies to all input files, so give them to linker, except if only linking
    *.a|*.so|*.o|*.obj)
      object_files="$object_files $1"
      object_files_xlinker="$object_files_xlinker -Xlinker $1"
      ;;
    #Handle object files which always need to use "-Xlinker": -x cu applies to all input files, so give them to linker, except if only linking
    @*|*.dylib)
      object_files="$object_files -Xlinker $1"
      object_files_xlinker="$object_files_xlinker -Xlinker $1"
      ;;
    #Handle shared libraries with *.so.* names which nvcc can't do.
    *.so.*)
      shared_versioned_libraries_host="$shared_versioned_libraries_host $1"
      shared_versioned_libraries="$shared_versioned_libraries -Xlinker $1"
      ;;
    #All other args are sent to the host compiler
    *)
      add_xcompiler_arg "$1"
      ;;
    esac

    shift
  done
}

parse_arguments "$@"

# --host-version: let the host compiler report its version and stop here.
# The wrapper's exit status is the status of that compiler invocation.
if (( get_host_version == 1 )); then
  $host_compiler --version
  exit $?
fi

# Remove duplicate entries from object_files, keeping the LAST occurrence of
# every entry and preserving the relative order of the survivors.
# An entry is either a single word or a "-Xlinker <word>" pair; pairs are
# compared and kept as a unit so that the -Xlinker prefix of one entry is
# never mistaken for a duplicate of the prefix of another one.
# Globals:  object_files (read and rewritten)
# Outputs:  one "Exists: <entry>" line on stdout per dropped duplicate
remove_duplicate_object_files() {
  local token prefix="" entry kept is_duplicate index
  local -a entries=() unique_reversed=()

  # Group the whitespace separated words into entries.
  for token in $object_files
  do
    if [ "$token" == "-Xlinker" ]; then
      prefix="-Xlinker "
      continue
    fi
    entries+=("$prefix$token")
    prefix=""
  done

  # Walk from the back so that the last occurrence is the one that is kept.
  for (( index = ${#entries[@]} - 1; index >= 0; index-- ))
  do
    entry=${entries[index]}
    is_duplicate=false
    for kept in "${unique_reversed[@]}"
    do
      if [ "$entry" == "$kept" ]; then
        is_duplicate=true
        echo "Exists: $entry"
        break
      fi
    done
    if [ "$is_duplicate" == "false" ]; then
      unique_reversed+=("$entry")
    fi
  done

  # Prepending while iterating the reversed list restores the original order.
  object_files=""
  for entry in "${unique_reversed[@]}"
  do
    object_files="$entry $object_files"
  done
}

#Remove duplicate object files
if [ "$remove_duplicate_link_files" -eq 1 ]; then
  remove_duplicate_object_files
fi

# nvcc is told which host compiler to use; supply the default one unless a
# -ccbin argument was already collected.
if [ "$ccbin_set" -ne 1 ]; then
  cuda_args+=" -ccbin $host_compiler"
fi

# Likewise fall back to the default architecture when none was requested.
if [ "$arch_set" -ne 1 ]; then
  cuda_args+=" -arch=$default_arch"
fi

# Assemble the nvcc command line; the host compiler arguments are only
# attached (as one comma-separated -Xcompiler list) if there are any.
nvcc_command="$nvcc_compiler $cuda_args $shared_args $xlinker_args $shared_versioned_libraries"
if [ "$first_xcompiler_arg" -eq 0 ]; then
  nvcc_command+=" -Xcompiler $xcompiler_args"
fi

# The host compiler itself takes its arguments space-separated, so the
# commas of the -Xcompiler list are turned into blanks.
xcompiler_args=${xcompiler_args//,/ }

# Assemble the command line used for --host-only.
host_command="$host_compiler $shared_args $host_only_args $compile_arg $output_arg $xcompiler_args $host_linker_args $shared_versioned_libraries_host"

# nvcc does not accept '#pragma ident SOME_MACRO_STRING' but it does accept '#ident SOME_MACRO_STRING'
#
# For every file in cpp_files that has a line mentioning "#", "pragma" and
# "ident", write a copy with '#pragma ident' replaced by '#ident' into
# temp_dir and list that copy in cpp_files instead of the original.
# The copy is named after the source path with every "/" turned into "_",
# so sources given with a directory part still get a valid file name
# directly inside temp_dir.
# Globals:  cpp_files (read and rewritten), temp_dir (read)
rewrite_pragma_ident_sources() {
  local file patched_file rewritten_files=""
  for file in $cpp_files
  do
    if grep pragma -- "$file" | grep ident | grep -q "#"; then
      patched_file="$temp_dir/nvcc_wrapper_tmp_${file//\//_}"
      sed 's/#[\ \t]*pragma[\ \t]*ident/#ident/g' "$file" > "$patched_file"
      rewritten_files="$rewritten_files $patched_file"
    else
      rewritten_files="$rewritten_files $file"
    fi
  done
  cpp_files=$rewritten_files
}

if [ "$replace_pragma_ident" -eq 1 ]; then
  rewrite_pragma_ident_sources
fi

# Attach the input files to both command lines.
# When sources are compiled, "-x cu" applies to every input file of nvcc, so
# object files are routed to the linker via -Xlinker; when only linking,
# nvcc gets the object files directly.
if [ -n "$cpp_files" ]; then
  nvcc_command+=" $object_files_xlinker -x cu $cpp_files"
  host_command+=" $object_files $cpp_files"
else
  nvcc_command+=" $object_files"
  host_command+=" $object_files"
fi

# For -MD/-MMD nvcc is run a second time to generate dependencies (without
# compiling).  That command is derived before -c and -o are appended below.
nvcc_depfile_command=""
if [ "$depfile_separate" -eq 1 ]; then
  nvcc_depfile_command="$nvcc_command -M $depfile_target_arg $depfile_output_arg"
fi

nvcc_command+=" $compile_arg $output_arg"

# --show / --nvcc-wrapper-show: print the command line that would be run
# (host command, nvcc command, or nvcc command plus dependency run) and exit
# successfully without executing anything.
if [ "$dry_run" -eq 1 ]; then
  if [ "$host_only" -eq 1 ]; then
    shown_command=$host_command
  elif [ -n "$nvcc_depfile_command" ]; then
    shown_command="$nvcc_command && $nvcc_depfile_command"
  else
    shown_command=$nvcc_command
  fi
  # deliberately unquoted: runs of blanks between the arguments collapse
  echo $shown_command
  exit 0
fi

# With NVCC_WRAPPER_SHOW_COMMANDS_BEING_RUN=1 the command line(s) are echoed
# before they are executed.
if [ "$NVCC_WRAPPER_SHOW_COMMANDS_BEING_RUN" == "1" ] ; then
  if [ "$host_only" -eq 1 ]; then
    echo "$host_command"
  elif [ -n "$nvcc_depfile_command" ]; then
    echo "TMPDIR=${temp_dir} $nvcc_command && TMPDIR=${temp_dir} $nvcc_depfile_command"
  else
    echo "TMPDIR=${temp_dir} $nvcc_command"
  fi
fi

# Run the compilation.  The nvcc invocations get TMPDIR set for the
# duplicated process only; the dependency run happens only if the compile
# step succeeded.  The command strings are expanded unquoted on purpose so
# that they split into separate arguments.
if [ "$host_only" -eq 1 ]; then
  $host_command
  error_code=$?
elif [ -n "$nvcc_depfile_command" ]; then
  TMPDIR=${temp_dir} $nvcc_command && TMPDIR=${temp_dir} $nvcc_depfile_command
  error_code=$?
else
  TMPDIR=${temp_dir} $nvcc_command
  error_code=$?
fi

# Hand the compiler's status back to the caller.
exit $error_code
