from __future__ import print_function
from builtins import object
from builtins import range
from builtins import str
import argparse
import glob
import logging
import os
import shutil
import subprocess
import sys
import tempfile

import base.proto
import pipedream.api

from base.file.utils import file_utils
from mapping.scripts.ll3d_pipeline import mapper_log_pb2
from mapping.distributed_mapping.proto import map_tile_info_pb2
from pipedream.api.job_from_bazel_target import get_job_from_bazel_target
from vehicle.perception.ll3d.proto import tile_info_pb2


def add_task_arguments(parser):
    parser.add_argument('--task_idx', type=int,
                        default=int(os.getenv('PIPEDREAM_TASK_INDEX', 0)),
                        help='Index of this task within the parallel task pool.')
    parser.add_argument('--task_count', type=int,
                        default=int(os.getenv('PIPEDREAM_TASK_COUNT', 1)),
                        help='Total number of tasks in the parallel task pool.')
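
# A minimal sketch of the sharding these flags enable; the item list below is
# hypothetical, and get_mesh_paths_subset (below) applies the same striding to
# mesh tiles:
#
#   parser = argparse.ArgumentParser()
#   add_task_arguments(parser)
#   args = parser.parse_args()
#   my_items = all_items[args.task_idx::args.task_count]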


class BazelTarget(object):
    """A bazel target identified by its package path and target name."""

    def __init__(self, package_path, target_name):
        self.path = package_path
        self.name = target_name

    def get_target(self):
        return "//" + self.path + ":" + self.name

    def get_path_to_bin(self):
        return os.path.join(self.path, self.name)

    def get_name(self):
        return self.name
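
# Illustrative only (the package and target below are hypothetical):
#
#   target = BazelTarget('mapping/scripts', 'tile_generator')
#   target.get_target()       # '//mapping/scripts:tile_generator'
#   target.get_path_to_bin()  # 'mapping/scripts/tile_generator'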


def find_file(dir_name, pattern):
    """Returns the file matching the given pattern in the given repository,
    asserting that there is one and only one matching file
    """
    assert os.path.exists(dir_name), "Input directory " + dir_name + " does not exist"
    assert os.path.isdir(dir_name), "Input directory " + dir_name + " is not a valid directory"
    files = glob.glob(os.path.join(dir_name, pattern))
    assert len(files) > 0, "Found no file matching pattern '{}' in {}".format(pattern, dir_name)
    assert len(files) == 1, "Found more than one file matching pattern '{}' in {}".format(pattern, dir_name)
    f = files[0]
    assert os.path.isfile(f), "Found " + f + " but it's not a valid file"
    return f
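
# Example (the directory is hypothetical): locate the single vehicle params
# file in a log directory, failing loudly if it is missing or ambiguous:
#
#   params_path = find_file('/data/run_0001/log_data', '*.yaml')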


class Task(object):
    def __init__(self, cmd=None, files_to_clean_on_failure=None):
        # Use None defaults instead of mutable ones: a shared default list
        # would leak state across Task instances.
        self.cmd = cmd if cmd is not None else []
        self.logs = []
        self.files_to_clean_on_failure = (
            files_to_clean_on_failure if files_to_clean_on_failure is not None else [])


class Pipeline(object):
    """A set of jobs, built either as a pipedream pipeline description
    (local=False) or as a list of commands to run locally via subprocess
    (local=True).
    """

    def __init__(self, wdir=None, name=None, local=False):
        self.wdir = wdir
        self.local = local
        self.name = name
        if self.local:
            self.pipe = []
        else:
            self.pipe = {
                "name": self.name,
                "job": [],
            }

    def make_job(self, bazel_target, branch=None, gpu=0, args=None):
        """Creates and returns a job for the given target.

        'bazel_target' is assumed to be of type BazelTarget.

        If 'branch' is not given, the job is built from the local branch,
        including local changes. Otherwise it runs on the given branch, which
        is assumed to exist in the remote driving repo.

        'args' is a list of command-line arguments for the target.

        If the pipeline was created with local=True, this returns a minimal
        job description whose command can be run via subprocess instead of a
        pipedream job description.
        """

        if not self.local:
            build_opts = ['-c opt']
            if gpu > 0:
                build_opts.append('--config=cuda')

            if branch:
                job = {
                    'name': bazel_target.get_name(),
                    'command': './' + bazel_target.get_name(),
                    'branch': branch,
                    'bazelTarget': [{'target': bazel_target.get_target(), 'buildOpt': build_opts}],
                }
            else:
                job = get_job_from_bazel_target(
                        target=bazel_target.get_target(), bazel_args=' '.join(build_opts))

            if gpu > 0:
                job['slurmOption'] = ['--gres=gpu:{}'.format(gpu)]

            job['dependency'] = []

            logging.info(job)

        else:
            # local job
            job = {
                'name': bazel_target.get_name(),
                'command': bazel_target.get_path_to_bin(),
            }

        if args:
            job['arg'] = args

        return job

    def append_job(self, job, dep_prefix=None, dep_type="AFTER_OK"):
        """Appends a job to the pipeline. If `dep_prefix` is provided, then
        add a dependency on all jobs whose name starts with the prefix. If no
        `dep_prefix` is provided, then the job will be assumed to be independent
        of all other jobs and will have no dependencies.
        """
        if job:
            if not self.local:
                if dep_prefix:
                    for stage in self.pipe["job"]:
                        if stage["name"].startswith(dep_prefix):
                            job['dependency'].append({"jobName": stage["name"], "dependencyType": dep_type})
                self.pipe["job"].append(job)
            else:
                self.pipe.append(job)

    def append_jobs(self, jobs, dep_prefix=None):
        """Appends the given jobs to the pipeline. If `dep_prefix` is provided,
        each job gains a dependency on every job already in the pipeline whose
        name starts with that prefix. Otherwise the jobs are assumed to be
        independent of all other jobs and have no dependencies.
        """
        if jobs:
            for job in jobs:
                self.append_job(job, dep_prefix)

    def run(self):
        if self.local:
            for job in self.pipe:
                task = Task()
                task.cmd = [job['command']]
                if 'arg' in job and job['arg']:
                    task.cmd.extend(job['arg'])
                print("Running local command:", task.cmd)
                run_tasks([task], log_dir=get_logs_dir(self.wdir, job['name'], create=True))
        else:
            assert self.name is not None
            pipedream.api.schedule(self.pipe)
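

# A minimal sketch of building and running a two-stage pipeline; the targets,
# working directory, and pipeline name below are hypothetical:
#
#   pipeline = Pipeline(wdir='/tmp/wdir', name='tile_pipeline', local=False)
#   gen = pipeline.make_job(BazelTarget('mapping/scripts', 'generate_tiles'),
#                           args=['--wdir', '/tmp/wdir'])
#   pipeline.append_job(gen)
#   merge = pipeline.make_job(BazelTarget('mapping/scripts', 'merge_tiles'))
#   pipeline.append_job(merge, dep_prefix='generate_tiles')
#   pipeline.run()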


def run_tasks(tasks, log_dir=None, num_retry=0):
    """Run the tasks, logging their output, and retrying them if they fail.

    'tasks' is a list of Task objects, each with a cmd ready to be passed to subprocess.
    'log_dir' is where the logs will be saved (assumed to exist and be writable).
    'num_retry' is the number of times we will re-run a task if it fails.

    Tasks that fail are printed on stdout (which allows finding what log files
    to inspect). If a task fails we will abort.
    """
    failed_tasks = []
    count = 0
    while len(tasks) > 0:
        task = tasks[0]
        tasks = tasks[1:]
        log_fd, log_fn = tempfile.mkstemp(dir=log_dir, suffix=".log")
        task.logs.append(log_fn)
        with os.fdopen(log_fd, "w") as log_file:
            try:
                my_env = os.environ.copy()
                my_env["DISPLAY"] = ":0"
                print('Running command {}: {}'.format(count, task.cmd))
                count += 1
                subprocess.check_call(task.cmd, env=my_env, stderr=log_file, stdout=log_file)
            except subprocess.CalledProcessError as ex:
                log_file.write("EXIT FAILURE\n")
                log_file.write(str(ex))
                if len(task.logs) > num_retry:
                    failed_tasks.append(task)
                else:
                    # requeue the task
                    tasks.append(task)
            except Exception:
                # Avoid a bare except so KeyboardInterrupt/SystemExit propagate.
                logging.exception("Unexpected error while launching: " + str(task.cmd))
                sys.exit(1)

    if len(failed_tasks) > 0:
        for t in failed_tasks:
            print("ERROR: " + " ".join(t.cmd))
            print("logs: " + str(t.logs))
            print("")
            for f in t.files_to_clean_on_failure:
                if os.path.isfile(f):
                    # CLM-7583: this command fails if the device runs out of disk space.
                    os.remove(f)
        sys.exit(1)
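
# A minimal sketch of running two tasks with one retry each; the commands and
# log directory are hypothetical:
#
#   tasks = [Task(cmd=['echo', 'hello']),
#            Task(cmd=['false'], files_to_clean_on_failure=['/tmp/partial.out'])]
#   run_tasks(tasks, log_dir='/tmp/logs', num_retry=1)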


def copy_if_source_file_exists(source_file, destination_dir):
    """
    Copy file from source to destination if the file exists.
    """
    if os.path.isfile(source_file):
        shutil.copy(source_file, destination_dir)


def get_lmdb_folder_list(lmdb_pbtxt):
    """
    Return the list of lmdb folders as specified in the lmdb_pbtxt proto file.
    A boolean skip_tile_generation flag may also be provided along with each
    folder path, indicating whether tile generation should be skipped for the
    corresponding folder; the list of these flags is returned as well (empty
    if none were set).
    """
    lmdb_dir_list_proto = mapper_log_pb2.LmdbFolderList()
    base.proto.ReadProtoAsText(lmdb_pbtxt, lmdb_dir_list_proto)

    lmdb_dir_list = list(lmdb_dir_list_proto.lmdb_dir)
    # Default to an empty list so we never return an undefined name when
    # skip_tile_generation is not set in the proto.
    skip_tile_dir_list = list(lmdb_dir_list_proto.skip_tile_generation)

    return lmdb_dir_list, skip_tile_dir_list
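
# For illustration, a hypothetical pbtxt read by get_lmdb_folder_list, using
# the two fields above, might look like:
#
#   lmdb_dir: "/maps/region_a/lmdb"
#   lmdb_dir: "/maps/region_b/lmdb"
#   skip_tile_generation: true
#   skip_tile_generation: false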


# Functions to centralize names of files and folders used in the pipeline

def check_dir(d, required=False, create=False):
    """Returns 'd' unchanged. Raises IOError if 'required' and the directory
    does not exist; creates it via file_utils.mkdir_p_public if 'create' is
    set. Without 'required' or 'create', the returned path may not exist.
    """
    if not os.path.isdir(d):
        if required:
            raise IOError("Required directory '{}' does not exist".format(d))
        if create:
            file_utils.mkdir_p_public(d)
    return d

def check_file(f, required=False):
    """Returns 'f' unchanged, raising IOError if 'required' and the file does not exist."""
    if required and not os.path.isfile(f):
        raise IOError("Required file '{}' does not exist".format(f))
    return f
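
# For instance (hypothetical path): check_dir('/tmp/out', create=True) creates
# /tmp/out if needed and returns it, while check_dir('/tmp/out', required=True)
# raises IOError when it is missing.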

def get_input_data_dir(wdir, required=False, create=False):
    return check_dir(os.path.join(wdir, 'log_data'), required, create)

def get_vehicle_params_file(wdir):
    return check_file(os.path.join(get_input_data_dir(wdir), "params.yaml"), required=True)

def get_logs_dir(wdir, task_name=None, required=False, create=False):
    d = os.path.join(wdir, 'logs')
    if task_name:
        d = os.path.join(d, task_name)
    return check_dir(d, required, create)

def read_tile_set_proto(filepath):
    tile_set = tile_info_pb2.TileSet()
    f = check_file(filepath, required=True)
    base.proto.ReadProto(f, tile_set)
    return tile_set

def get_lmdb_list_file(lmdb_dir):
    return os.path.join(lmdb_dir, "lmdb_list.pbtxt")

def get_tile_list_file(lmdb_dir):
    return os.path.join(lmdb_dir, "tile_list.pbtxt")

def get_raw_3d_tile_dir(multi_log_proto, required=False, create=False):
    return check_dir(os.path.join(multi_log_proto.wdir,
        'raw', 'Tiles-{}-25-10'.format(multi_log_proto.utm_zone)), required, create)

def get_raw_alt_tile_dir(multi_log_proto, required=False, create=False):
    return check_dir(os.path.join(multi_log_proto.wdir,
        'raw', 'Altitudes-{}-25-50'.format(multi_log_proto.utm_zone)), required, create)

def get_raw_alt_mesh_tile_dir(multi_log_proto, required=False, create=False):
    return check_dir(os.path.join(multi_log_proto.wdir,
        'raw', 'AltitudesMesh-{}-25-50'.format(multi_log_proto.utm_zone)), required, create)

def get_simplified_3d_tile_dir(multi_log_proto, required=False, create=False):
    return check_dir(os.path.join(multi_log_proto.wdir,
        'simplified', 'Tiles-{}-25-10'.format(multi_log_proto.utm_zone)), required, create)

def get_simplified_alt_tile_dir(multi_log_proto, required=False, create=False):
    return check_dir(os.path.join(multi_log_proto.wdir,
        'simplified', 'Altitudes-{}-25-50'.format(multi_log_proto.utm_zone)), required, create)

def get_simplified_alt_mesh_tile_dir(multi_log_proto, required=False, create=False):
    return check_dir(os.path.join(multi_log_proto.wdir,
        'simplified', 'AltitudesMesh-{}-25-50'.format(multi_log_proto.utm_zone)), required, create)

def get_images_dir(multi_log_proto, required=False, create=False):
    return check_dir(os.path.join(multi_log_proto.wdir, 'images'), required, create)

def get_intensity_tiles_dir(multi_log_proto, required=False, create=False):
    return check_dir(os.path.join(multi_log_proto.wdir, 'images', 'laser'), required, create)

def get_3d_tile_name(x, y):
    return 'tile-{}-{}.ply'.format(x, y)

def get_lmdb_tile_name(x, y):
    return 'tile-{}-{}.lmdb'.format(x, y)

def get_alt_tile_name(x, y):
    return 'altitude-{}-{}.pb'.format(x, y)
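
# For instance (illustrative coordinates), the tile at grid position (12, 34)
# maps to the following per-format file names:
#
#   get_3d_tile_name(12, 34)    # 'tile-12-34.ply'
#   get_lmdb_tile_name(12, 34)  # 'tile-12-34.lmdb'
#   get_alt_tile_name(12, 34)   # 'altitude-12-34.pb'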

def get_lmdb_folder_list_pbtxt_file(wdir):
    return os.path.join(wdir, 'lmdb_folder_list.pbtxt')

def get_bin_tiles_dir(multi_log_proto, required=False, create=False):
    raw_tile_dir = get_raw_3d_tile_dir(multi_log_proto, required, create)
    # check_dir always returns the path it was given, so test existence
    # explicitly to decide whether to bail out.
    if not os.path.isdir(raw_tile_dir):
        return False
    bin_tile_dir = os.path.join(raw_tile_dir,
            'bin_tiles', 'Tiles-{}-25-10'.format(multi_log_proto.utm_zone))
    return check_dir(bin_tile_dir, required, create)

def get_ortho_tiles_dirs_and_flags(wdir, required=False, create=False):
    dirs_and_flags = dict()

    # check_dir always returns the path it was given, so test existence
    # explicitly to decide whether to bail out.
    normals_dir = check_dir(os.path.join(wdir, 'images', 'mesh_normals'), required, create)
    if not os.path.isdir(normals_dir):
        return False
    dirs_and_flags['normals'] = normals_dir

    double_ground_dir = check_dir(os.path.join(wdir, 'images', 'double_ground'), required, create)
    if not os.path.isdir(double_ground_dir):
        return False
    dirs_and_flags['double_ground'] = double_ground_dir

    return dirs_and_flags

def get_map_overlay_dir(wdir, required=False, create=False):
    return check_dir(os.path.join(wdir, 'map_overlay'), required, create)

def get_mesh_quality_dir(multi_log_proto, required=False, create=False):
    raw_tile_dir = get_raw_3d_tile_dir(multi_log_proto, False, False)
    return check_dir(os.path.join(raw_tile_dir, 'mesh_validation'), required, create)

def get_tile_set_path(multi_log_proto):
    mesh_dir = get_raw_3d_tile_dir(multi_log_proto, create=True)
    return os.path.join(mesh_dir, 'tile_set.pbtxt')

def get_mesh_paths_subset(info_proto_path, task_idx, task_count):
    """Iterate through the info proto list and return the subset of tile file
    paths assigned to this task for parallel processing.
    """
    info_proto = map_tile_info_pb2.MapTileInfoList()
    base.proto.ReadProto(info_proto_path, info_proto)

    mesh_tile_sublist = list()
    # Stride by the number of tasks in the parallel processing pool so the
    # tiles are distributed evenly: task task_idx gets every task_count-th tile.
    for info_idx in range(task_idx, len(info_proto.map_tile_info), task_count):
        mesh_tile_sublist.append(info_proto.map_tile_info[info_idx].file_path)

    return mesh_tile_sublist
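
# For example (illustrative counts): with 7 tiles and task_count=3, task 0
# gets indices [0, 3, 6], task 1 gets [1, 4], and task 2 gets [2, 5], so the
# tasks together cover every tile exactly once.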

def copy_tiles(mesh_tile_sublist, output_folder):
    """Copy the tiles in the list into the given folder. Returns the list of new paths."""
    new_path_list = list()
    for tile_path in mesh_tile_sublist:
        fname = os.path.basename(tile_path)
        new_path = os.path.join(output_folder, fname)
        shutil.copyfile(tile_path, new_path)
        new_path_list.append(new_path)

    return new_path_list
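
# A minimal sketch tying the two helpers together; the proto path, scratch
# directory, and args (from add_task_arguments) are hypothetical:
#
#   subset = get_mesh_paths_subset('/maps/tile_info.pb', args.task_idx, args.task_count)
#   local_paths = copy_tiles(subset, '/tmp/mesh_scratch')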
