#!/usr/bin/env python3
from __future__ import print_function
import os
import sys
import subprocess
import os.path


def cpp_arg(arg):
    """Return True if *arg* is a preprocessor flag.

    A preprocessor flag is any argument starting with ``-I``, ``-D`` or
    ``-U``; these are forwarded to the manual preprocessing step.
    """
    preprocessor_prefixes = ("-I", "-D", "-U")
    return arg.startswith(preprocessor_prefixes)


def check_call(args, **kwargs):
    """Run a command, echoing it first when verbose mode is on.

    Verbose mode is enabled by setting the environment variable ``V`` to
    ``1``. All keyword arguments are passed to ``subprocess.check_call``.

    Returns:
        The return value of ``subprocess.check_call`` (0 on success).

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    verbose = os.environ.get("V") == "1"
    if verbose:
        command_line = " ".join(args)
        print(command_line)
    return subprocess.check_call(args, **kwargs)


def check_call_redirect(args, filename=None, **kwargs):
    """Run a command with its standard output redirected into *filename*.

    When the environment variable ``V`` is ``1`` the command is echoed in
    ``cmd > filename`` form before it runs.

    Args:
        args: Command and arguments, as a list of strings.
        filename: Path of the file that receives the command's stdout.
            It is opened in binary write mode, so a path is required
            despite the ``None`` default.
        **kwargs: Passed through to ``subprocess.check_call``.

    Returns:
        The return value of ``subprocess.check_call`` (0 on success).

    Raises:
        SystemExit: with the command's exit status if it exits non-zero.
        OSError: if the command cannot be launched.

    On either failure the (partial or empty) output file is removed.
    """
    if os.getenv("V") == "1":
        print(" ".join(args), ">", filename)
    # Opened outside the try block: if the open itself fails there is
    # nothing of ours to clean up and the error simply propagates.
    fd = open(filename, "wb")
    try:
        # The handle is closed when the with block exits, i.e. before any
        # handler below runs, so the file is never deleted while open.
        with fd:
            return subprocess.check_call(args, stdout=fd, **kwargs)
    except subprocess.CalledProcessError as e:
        os.remove(filename)
        raise SystemExit(e.returncode)
    except OSError:
        # The command could not be started; do not leave an empty file.
        os.remove(filename)
        raise

def is_nvidia_backend():
    """Report whether ``hipcc --version`` mentions ``nvcc``.

    Runs ``hipcc --version`` and searches its output case-insensitively
    for the string "nvcc".

    Returns:
        True if the output contains "nvcc". False otherwise, including
        when hipcc is missing, cannot be executed, or exits non-zero.
    """
    try:
        output = subprocess.check_output(["hipcc", "--version"], universal_newlines=True)
    except (subprocess.CalledProcessError, OSError):
        # Could not run hipcc --version. OSError covers a missing hipcc
        # as well as one that exists but cannot be executed.
        return False
    return "nvcc" in output.lower()

# Split the wrapped command line into three groups:
#   cpp_args - flags that the manual preprocessing step also needs,
#   files    - source files that get preprocessed,
#   args     - everything else (the compiler and its remaining arguments).
source_suffixes = (".F90", ".cu", ".hip")
args = sys.argv[1:]
cpp_args = [arg for arg in args if cpp_arg(arg)]
files = [arg for arg in args if arg.endswith(source_suffixes)]
args = [arg for arg in args if not arg.endswith(source_suffixes)]

if len(files) > 1:
    raise Exception("Specify exactly one source file (.F90, .cu, or .hip)")
elif len(files) == 0:
    # No ".F90", ".cu", ".hip" file specified, execute program as-is
    if not args:
        # Nothing to execute at all; fail cleanly instead of indexing
        # into an empty list.
        print("Error: no command specified", file=sys.stderr)
        raise SystemExit(1)
    try:
        # On success this replaces the current process and never returns.
        os.execvp(args[0], args)
    except OSError as e:
        print("Error executing '{0}': {1}".format(args[0], e.args[1]))
        raise SystemExit(1)
else:
    file, = files

# Split the source file name into its base name and extension.
base, ext = os.path.splitext(file)

# When an output name is given with "-o", it is folded into the temporary
# file name (followed by the source extension).
if "-o" in args:
    output = args.index("-o")
    outputname = args[output + 1]
    output_suffix = "-" + outputname.replace("/", "_") + ext
else:
    output_suffix = ""

# Flatten path separators and dashes to underscores, then keep only the
# last 250 characters of the resulting name.
tmp_filename = "manually_preprocessed_" + file.replace("/", "_") + output_suffix
tmp_filename = tmp_filename.replace("-", "_")[-250:]

if ext == ".F90":
    # preprocess: cpp writes to stdout, which is captured in the temp file
    check_call_redirect(
        ["cpp", "-P", "-traditional", "-Wall", "-Werror"] + cpp_args + [file],
        filename=tmp_filename)

    # compile the preprocessed file in place of the original source
    check_call(args + [tmp_filename])
elif ext in (".cu", ".hip"):
    # pick the command used for the preprocessing run
    if ext == ".cu":
        preprocess_cmd = ["nvcc"]
    elif is_nvidia_backend():
        # HIP source on an nvcc-backed hipcc: treat the file as CUDA
        preprocess_cmd = ["nvcc", "-x", "cu"]
    else:
        preprocess_cmd = ["hipcc"]

    # preprocess into the temp file
    check_call(preprocess_cmd + cpp_args + ["-E", file, "-o", tmp_filename])

    # compile: the original command line is run unchanged
    check_call(sys.argv[1:])

# cleanup (may be commented out for better debuggability)
os.remove(tmp_filename)
