#!/usr/bin/env python
"""
This script is meant to be run outside of bazel (as it needs access to the local
git repository). It will query the timing benchmarks table to get a sense of the
SHAs that impacted average timing the most.
"""

import subprocess32 as subprocess
import psycopg2
import psycopg2.extras
import shlex
import pandas as pd

import logging

log = logging.getLogger(__name__)


def report_error(branch, branch_sha=None):
    msg = "No data on branch `%s`"
    if branch_sha is not None:
        msg += " (%s)."
    else:
        msg += "."

    msg += " To generate data, please run `git checkout %s && brun //vehicle/planner/metrics:benchmark_timing -- --pipedream.`"

    if branch_sha is not None:
        log.error(msg, branch, branch_sha, branch_sha)
    else:
        log.error(msg, branch, branch)


def zci_tick_all_data(shas, branch):
    """Provides all the zci_tick_all data for a given branch.

    Parameters
    ----------
    shas: list[str]
        A list of shas to look for

    Yields
    ------
    psycopg2.Row
        Presents a row of the zci_tick_all data as a dictionary.
    """
    conn = psycopg2.connect(
        "dbname='planner_benchmarks' "
        "user='planner_benchmarks' "
        "host='internal-postgres.cp4rxnpmuhoe.us-west-1.rds.amazonaws.com' "
        "password='keijaiquahhaujahnohphasaepoogaezoda'"
    )

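    # DictCursor lets rows be accessed by column name (e.g. row["git_sha"]).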
    cur = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)

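    # The CTE restricts metadata rows to the requested shas and branch; the
    # outer query averages zci_tick_all scope durations per (branch, sha).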
    sql_query = """
        with cte_id as (
            select id, metrics_metadata.branch, metrics_metadata.git_sha
            from metrics_metadata
            where git_sha = ANY(%s)
            and branch = %s
        )
        select
            branch,
            git_sha,
            avg(stop_time - start_time) as avg_duration
        from timing_metrics
        inner join cte_id
        on cte_id.id = timing_metrics.metadata_id
        where scope_name = 'zci_tick_all'
        group by branch, git_sha;
    """

    cur.execute(sql_query, (shas, branch))

    for row in cur.fetchall():
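        # float() handles the Decimal that Postgres avg() typically returns;
        # durations are assumed to be stored in microseconds, hence /1000.0
        # to report milliseconds.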
        yield {
            "sha": row["git_sha"],
            "avg_duration": float(row["avg_duration"]) / 1000.0,
        }

    cur.close()
    conn.close()


def get_git_log(limit=100):
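    # --first-parent --no-merges keeps mainline, non-merge commits; the
    # pretty format emits "<sha> <author>\t<date>\t<subject>", and --graph
    # prefixes each commit line with "* ".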
    cmd = shlex.split(
        'git log --first-parent --no-merges --pretty=format:"%H %an%x09%ad%x09%s" --graph -n {}'.format(
            limit
        )
    )
    output = subprocess.check_output(cmd)
    return output


def parse_hashes(git_log):
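    # Because of --graph, each line starts with "* ", so the sha is the
    # second space-separated token.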
    return [x.split(" ")[1] for x in git_log.split("\n")]


def main():
    branch = "nightly-develop/planner"
    git_log = get_git_log(limit=200)
    hashes = parse_hashes(git_log)

    timing_info = pd.DataFrame(list(zci_tick_all_data(hashes, branch)))
    if timing_info.empty:
        report_error(branch)
        return

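    # Annotate each git log line with its measured average tick duration,
    # or a width-matched placeholder when no data exists for that sha.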
    for line in git_log.split("\n"):
        line_data = line.split(" ")
        sha = line_data[1]

        timings = timing_info[timing_info.sha == sha]
        if timings.empty:
            line_data.insert(1, "[============]")
        else:
            line_data.insert(1, "[avg: %0.2fms]" % timings.avg_duration)

        print(" ".join(line_data))


if __name__ == "__main__":
    main()
