from bokeh.plotting import figure
from bokeh.layouts import gridplot
from bokeh.models import (
    ColumnDataSource,
    Span,
    HoverTool,
    DatetimeTickFormatter,
    Range1d,
    LinearAxis,
)
from bokeh.models.widgets import Slider
from bokeh.models.callbacks import CustomJS
from bokeh.palettes import Category10_10
from bokeh.resources import INLINE
from bokeh.embed import file_html


def boxplot(df, col, overlay_counts=False):
    """Build a Bokeh box-and-whisker plot of ``col`` for each group of ``df``.

    Repurposed from the Bokeh boxplot gallery example:
    https://docs.bokeh.org/en/latest/docs/gallery/boxplot.html

    Parameters
    ----------
    df : pandas GroupBy
        Grouped data (presumably a ``DataFrameGroupBy`` -- the outlier
        detection relies on ``group.name`` inside ``apply``). Group labels
        are used directly as x coordinates and offset by +/-0.1 for the
        whisker caps, so they must be numeric.
    col : str
        Column whose per-group distribution is plotted.
    overlay_counts : bool, default False
        If True, also draw translucent bars with the number of non-null
        observations per group, scaled against a secondary y-axis on the
        right.

    Returns
    -------
    bokeh.plotting.figure
        The configured figure.
    """
    # Quartiles, IQR and Tukey fences per group. Each is a DataFrame whose
    # rows are the group labels and whose columns are the data columns.
    q1 = df.quantile(q=0.25)
    q2 = df.quantile(q=0.5)
    q3 = df.quantile(q=0.75)
    iqr = q3 - q1
    upper = q3 + 1.5 * iqr
    lower = q1 - 1.5 * iqr

    cats = q1.index.values
    group_max = df.max()[col]
    group_min = df.min()[col]

    # Whiskers stop at the fences but never extend past the observed data
    # (as in the Bokeh gallery example). Unclamped, the upper whisker often
    # overshoots every point and gets cut off by the y-range set below.
    upper_whisker = upper[col].clip(upper=group_max)
    lower_whisker = lower[col].clip(lower=group_min)

    outx, outy = _outlier_points(df, col, lower, upper)

    p = figure(
        tools="hover",
        y_range=[0, 1.1 * group_max.max()],
        toolbar_location=None,
    )

    # stems: from the box edges out to the whisker ends
    p.segment(cats, upper_whisker, cats, q3[col], line_color="black")
    p.segment(cats, lower_whisker, cats, q1[col], line_color="black")

    # boxes: upper half (median to Q3) and lower half (Q1 to median)
    p.vbar(
        cats, 0.5, q2[col], q3[col], fill_color="#E08E79", line_color="black"
    )
    p.vbar(
        cats, 0.5, q1[col], q2[col], fill_color="#3B8686", line_color="black"
    )

    # whisker caps: short horizontal segments centred on each category
    p.segment(
        cats - 0.1, lower_whisker, cats + 0.1, lower_whisker, line_color="black"
    )
    p.segment(
        cats - 0.1, upper_whisker, cats + 0.1, upper_whisker, line_color="black"
    )

    # outliers
    if outx:
        p.circle(outx, outy, size=6, color="#F38630", fill_alpha=0.6)

    if overlay_counts:
        counts = df.count()
        # secondary axis so the count bars do not share the data's scale;
        # the 1.5x headroom keeps the bars below the top of the plot
        p.extra_y_ranges = {
            "count_y": Range1d(start=0, end=counts[col].max() * 1.5)
        }
        y2_axis = LinearAxis(y_range_name="count_y")
        y2_axis.axis_label = "# Observations"

        p.add_layout(y2_axis, "right")
        p.vbar(
            x=cats,
            top=counts[col],
            bottom=0,
            width=0.25,
            y_range_name="count_y",
            level="underlay",
            alpha=0.25,
        )

    return p


def _outlier_points(df, col, lower, upper):
    """Return ``(xs, ys)`` for every ``col`` value outside its group's fences.

    ``xs`` holds each outlier's group label and ``ys`` its value. ``lower``
    and ``upper`` are per-group fence frames indexed by group label.
    """

    def outliers(group):
        cat = group.name
        values = group[col]
        outside = (values > upper.loc[cat, col]) | (values < lower.loc[cat, col])
        return values[outside]

    out = df.apply(outliers)
    if out.ndim == 2:
        # When every group's result shares the same index (e.g. there is only
        # one group), pandas returns a wide frame with one row per group
        # instead of a long Series; stack it back into (group, row) form.
        out = out.stack()
    out = out.dropna()

    if out.index.nlevels > 1:
        # usual case: index is (group label, original row label, ...)
        xs = list(out.index.get_level_values(0))
    else:
        # group keys were not prepended (presumably group_keys=False); keep
        # the previous behaviour of plotting at the row label
        xs = list(out.index)
    return xs, list(out)


def barplot(df, col):
    """Return a bar chart of the non-null ``col`` count in each group of ``df``.

    ``df`` is presumably a pandas GroupBy: ``df.count()`` is read as one row
    per group label, and those labels become the bars' x positions.
    """
    per_group = df.count()[col]

    chart = figure(tools="", toolbar_location=None)
    chart.vbar(x=per_group.index.values, top=per_group, bottom=0, width=0.5)
    return chart
