import copy
from collections import defaultdict
from typing import Optional, Callable, Union, List, Dict, Any, Tuple
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from target_extraction.data_types import TargetTextCollection, TargetText
from target_extraction.analysis import sentiment_metrics
from target_extraction.analysis.sentiment_error_analysis import (distinct_sentiment,
                                                                 swap_list_dimensions,
                                                                 reduce_collection_by_key_occurrence,
                                                                 swap_and_reduce)
from target_extraction.analysis.statistical_analysis import one_tailed_p_value
def metric_df(target_collection: TargetTextCollection, 
              metric_function: Callable[[TargetTextCollection, str, str, bool, 
                                         bool, Optional[int], bool], 
                                        Union[float, List[float]]],
              true_sentiment_key: str, predicted_sentiment_keys: List[str], 
              average: bool, array_scores: bool, 
              assert_number_labels: Optional[int] = None,
              ignore_label_differences: bool = True, 
              metric_name: str = 'metric', 
              include_run_number: bool = False) -> pd.DataFrame:
    '''
    Apply `metric_function` to each of the `predicted_sentiment_keys` and 
    collect the resulting scores into a DataFrame.

    :param target_collection: Collection of targets that have true and predicted 
                              sentiment values.
    :param metric_function: A metric function from 
                            :py:func:`target_extraction.analysis.sentiment_metrics`
    :param true_sentiment_key: Key in the `target_collection` targets that 
                               contains the true sentiment scores for each 
                               target in the TargetTextCollection
    :param predicted_sentiment_keys: The name of the predicted sentiment keys 
                                     within the TargetTextCollection for 
                                     which the metric function should be applied
                                     to.
    :param average: For each predicted sentiment key it will return the 
                    average metric score across the *N* predictions made for 
                    each predicted sentiment key.
    :param array_scores: If average is False then this will return all of the 
                         *N* model runs metric scores.
    :param assert_number_labels: Whether or not to assert this many number of unique  
                                 labels must exist in the true sentiment key. 
                                 If this is None then the assertion is not raised.
    :param ignore_label_differences: If True then the ValueError will not be 
                                     raised if the predicted sentiment values 
                                     are not in the true sentiment values.
    :param metric_name: The name to give to the metric value column.
    :param include_run_number: If `array_scores` is True then this will add an 
                               extra column to the returned dataframe (`run number`) 
                               which will include the model run number. This can 
                               be used to uniquely identify each row when combined 
                               with the `prediction key` string.
    :returns: A pandas DataFrame with the columns: 1. `metric_name` (the 
              metric value), 2. `prediction key` (the prediction key string), 
              and 3. `run number` (only when `include_run_number` is True). 
              The number of rows is either the number of prediction keys when 
              the metric function returns a single score per key (e.g. 
              `average` is `True`) or number of prediction keys * number of 
              model runs when it returns a list of scores (`array_scores` is 
              `True`).
    :raises ValueError: If `include_run_number` is True and `array_scores` is 
                        False.
    '''
    # Run numbers only exist when one score per model run is returned. The 
    # parameter is a bool, so test its truth value (comparing against None 
    # would reject every call where `array_scores` is False).
    if include_run_number and not array_scores:
        raise ValueError('Can only have `include_run_number` as True if '
                         '`array_scores` is also True')
    df_predicted_keys: List[str] = []
    df_metric_values: List[float] = []
    df_run_numbers: List[int] = []
    for predicted_sentiment_key in predicted_sentiment_keys:
        metric_scores = metric_function(target_collection, true_sentiment_key, 
                                        predicted_sentiment_key, average=average, 
                                        array_scores=array_scores, 
                                        assert_number_labels=assert_number_labels,
                                        ignore_label_differences=ignore_label_differences)
        if isinstance(metric_scores, list):
            # One score per model run; the list position is the run number.
            for run_number, metric_score in enumerate(metric_scores):
                df_metric_values.append(metric_score)
                df_predicted_keys.append(predicted_sentiment_key)
                df_run_numbers.append(run_number)
        else:
            df_metric_values.append(metric_scores)
            df_predicted_keys.append(predicted_sentiment_key)
    df_dict = {f'{metric_name}': df_metric_values, 
               'prediction key': df_predicted_keys}
    if include_run_number:
        df_dict['run number'] = df_run_numbers
    return pd.DataFrame(df_dict)
def metric_p_values(data_split_df: pd.DataFrame, better_split: str, 
                    compare_splits: List[str], datasets: List[str], 
                    metric_names_assume_normals: List[Tuple[str, bool]],
                    better_and_compare_column_name: str = 'Model'
                    ) -> pd.DataFrame:
    '''
    For every compared split, dataset, and metric compute the one tailed 
    p-value of the hypothesis that `better_split` scores higher than the 
    compared split, using 
    :py:func:`target_extraction.analysis.statistical_analysis.one_tailed_p_value`.

    :param data_split_df: The DataFrame that contains at least the following 
                          columns: 1. value for `better_and_compare_column_name`,
                          2. `Dataset`, and 3. all `metric name` 
    :param better_split: The name of the model you are testing if it is better
                         than all other models in the `compare_splits`
    :param compare_splits: The name of the models you assume are no different 
                           in score to the `better_split` model.
    :param datasets: Datasets to test the hypothesis on.
    :param metric_names_assume_normals: A list of Tuples that contain 
                                        (metric name, assumed to be normal)
                                        where the `assumed to be normal` is False 
                                        or True based on whether the metric scores 
                                        from `metric name` column can be assumed to be 
                                        normal or not. e.g. [(`Accuracy`, True)]
    :param better_and_compare_column_name: The column that contains the 
                                           `better_split` and `compare_splits` 
                                           values.
    :returns: A DataFrame containing the following columns (in this order): 
              1. Compared {better_and_compare_column_name}, 2. Metric, 
              3. Dataset, 4. P-Value, and 5. Better {better_and_compare_column_name}. 
              There is one row per (compared split, dataset, metric).
    '''
    split_column = f'{better_and_compare_column_name}'
    compared_column = f'Compared {better_and_compare_column_name}'
    better_column = f'Better {better_and_compare_column_name}'
    # Insertion order of this dict defines the column order of the result.
    results: Dict[str, List[Any]] = {compared_column: [], 'Metric': [], 
                                     'Dataset': [], 'P-Value': [], 
                                     better_column: []}

    working_df = data_split_df.copy(deep=True)
    split_values = working_df[split_column]
    better_rows = working_df[split_values == f'{better_split}']
    for compare_split in compare_splits:
        compare_rows = working_df[split_values == f'{compare_split}']
        for dataset in datasets:
            better_subset = better_rows[better_rows['Dataset'] == dataset]
            compare_subset = compare_rows[compare_rows['Dataset'] == dataset]
            for metric_name, assume_normal in metric_names_assume_normals:
                metric_column = f'{metric_name}'
                p_value = one_tailed_p_value(better_subset[metric_column], 
                                             compare_subset[metric_column], 
                                             assume_normal=assume_normal)
                results[compared_column].append(compare_split)
                results['Metric'].append(metric_name)
                results['Dataset'].append(dataset)
                results['P-Value'].append(p_value)
                results[better_column].append(better_split)
    return pd.DataFrame(results)
def combine_metrics(metric_df: pd.DataFrame, other_metric_df: pd.DataFrame, 
                    other_metric_name: str) -> pd.DataFrame:
    '''
    Copy the `other_metric_name` column of `other_metric_df` into a copy of 
    `metric_df`, matching rows on the `prediction key` and `run number` 
    columns.

    :param metric_df: DataFrame that contains all the metrics to be kept
    :param other_metric_df: Contains metric scores that are to be added to a copy 
                            of `metric_df`
    :param other_metric_name: Name of the column of the metric scores to be copied
                              from `other_metric_df`
    :returns: A copy of the `metric_df` with a new column `other_metric_name`
              that contains the other metric scores. The `prediction key` and 
              `run number` columns become the first two columns. Rows of 
              `metric_df` with no matching row in `other_metric_df` get NaN.
    :Note: This assumes that the two dataframes come from 
           :py:func:`target_extraction.analysis.util.metric_df` with the argument 
           `include_run_number` as True. This is due to the columns used to 
           combine the metric scores are `prediction key` and `run number`.
    :raises KeyError: If `prediction key` and `run number` are not columns 
                      within `metric_df` and `other_metric_df`
    '''
    index_keys = ['prediction key', 'run number']
    for df_name, df in (('metric_df', metric_df), 
                        ('other_metric_df', other_metric_df)):
        df_columns = df.columns
        for index_key in index_keys:
            if index_key not in df_columns:
                raise KeyError(f'The following column {index_key} does not '
                               f'exist in {df_name} dataframe. The following'
                               f' columns do exist {df_columns}')

    new_metric_df = metric_df.copy(deep=True).set_index(index_keys)
    other_metric_scores = other_metric_df.set_index(index_keys)[other_metric_name]
    # Column assignment aligns on the (prediction key, run number) index, so 
    # the row order of `other_metric_df` does not matter.
    new_metric_df[other_metric_name] = other_metric_scores
    return new_metric_df.reset_index()
def overall_metric_results(collection: TargetTextCollection, 
                           prediction_keys: Optional[List[str]] = None,
                           true_sentiment_key: str = 'target_sentiments',
                           strict_accuracy_metrics: bool = False
                           ) -> pd.DataFrame:
    '''
    Compute per model run Accuracy and Macro F1 (and optionally the Strict 
    Target Accuracy variants) for every prediction key in `collection`.

    :param collection: Dataset that contains all of the results. Furthermore it 
                       should have the name attribute as something meaningful 
                       e.g. `Laptop` for the Laptop dataset.
    :param prediction_keys: A list of prediction keys that you want the results 
                            for. If None then it will get all of the prediction 
                            keys from 
                            `collection.metadata['predicted_target_sentiment_key']`.
    :param true_sentiment_key: Key in the `target_collection` targets that 
                               contains the true sentiment scores for each 
                               target in the TargetTextCollection.
    :param strict_accuracy_metrics: If this is True the dataframe will also 
                                    contain three additional columns: 'STAC',
                                    'STAC 1', and 'STAC Multi'. Where 'STAC'
                                    is the Strict Target Accuracy (STAC) on the 
                                    whole dataset, 'STAC 1' and 'STAC Multi' is 
                                    the STAC metric performed on the subset of 
                                    the dataset that contain either one unique 
                                    sentiment or more than one unique sentiment 
                                    per text respectively.
    :returns: A pandas dataframe with the following columns: `['prediction key', 
              'run number', 'Accuracy', 'Macro F1', 'Dataset']`. The `Dataset`
              column will contain one unique value and that will come from 
              the `name` attribute of the `collection`. The DataFrame will 
              also contain columns and values from the associated metadata see
              :py:func:`add_metadata_to_df` for more details.
    '''
    if prediction_keys is None:
        metadata_keys = collection.metadata['predicted_target_sentiment_key']
        prediction_keys = list(metadata_keys.keys())

    def _per_run_scores(score_collection, metric_function, metric_name):
        # One row per (prediction key, run number) holding that run's score.
        return metric_df(score_collection, metric_function, 
                         true_sentiment_key, prediction_keys,
                         array_scores=True, assert_number_labels=3, 
                         metric_name=metric_name, average=False, 
                         include_run_number=True)

    accuracy_scores = _per_run_scores(collection, sentiment_metrics.accuracy, 
                                      'Accuracy')
    accuracy_scores = add_metadata_to_df(accuracy_scores, collection, 
                                         'predicted_target_sentiment_key')
    macro_f1_scores = _per_run_scores(collection, sentiment_metrics.macro_f1, 
                                      'Macro F1')
    results = combine_metrics(accuracy_scores, macro_f1_scores, 'Macro F1')

    if strict_accuracy_metrics:
        # Work on a deep copy so the caller's collection is left untouched.
        distinct_collection = distinct_sentiment(copy.deepcopy(collection), 
                                                 separate_labels=True, 
                                                 true_sentiment_key=true_sentiment_key)
        stac_subsets = [(['distinct_sentiment_2', 'distinct_sentiment_3'], 'STAC Multi'),
                        ('distinct_sentiment_1', 'STAC 1')]
        for subset_keys, stac_name in stac_subsets:
            subset_collection = swap_and_reduce(copy.deepcopy(distinct_collection), 
                                                subset_keys, true_sentiment_key, 
                                                prediction_keys)
            subset_scores = _per_run_scores(subset_collection, 
                                            sentiment_metrics.strict_text_accuracy, 
                                            stac_name)
            results = combine_metrics(results, subset_scores, stac_name)
            # Free the (potentially large) subset copy before the next one.
            del subset_collection
        del distinct_collection
        overall_stac = _per_run_scores(collection, 
                                       sentiment_metrics.strict_text_accuracy, 
                                       'STAC')
        results = combine_metrics(results, overall_stac, 'STAC')

    results['Dataset'] = [collection.name] * results.shape[0]
    return results
def plot_error_subsets(metric_df: pd.DataFrame, df_column_name: str, 
                       df_row_name: str, df_x_name: str, df_y_name: str,
                       df_hue_name: str = 'Model', 
                       seaborn_plot_name: str = 'pointplot',
                       seaborn_kwargs: Optional[Dict[str, Any]] = None,
                       legend_column: Optional[int] = 0,
                       figsize: Optional[Tuple[float, float]] = None,
                       legend_bbox_to_anchor: Tuple[float, float] = (-0.13, 1.1),
                       fontsize: int = 14, legend_fontsize: int = 10,
                       tick_font_size: int = 12, 
                       title_on_every_plot: bool = False,
                       df_overall_metric: Optional[str] = None,
                       overall_seaborn_plot_name: Optional[str] = None,
                       overall_seaborn_kwargs: Optional[Dict[str, Any]] = None,
                       df_dataset_size: Optional[str] = None,
                       dataset_h_line_offset: float = 0.2,
                       dataset_h_line_color: str = 'k',
                       h_line_legend_name: str = 'Dataset Size (Number of Samples)',
                       h_line_legend_bbox_to_anchor: Optional[Tuple[float, float]] = None,
                       dataset_y_label: str = 'Dataset Size\n(Number of Samples)',
                       gridspec_kw: Optional[Dict[str, Any]] = None,
                       row_order: Optional[List[Any]] = None,
                       column_order: Optional[List[Any]] = None
                       ) -> Tuple[matplotlib.figure.Figure, 
                                  List[List[matplotlib.axes.Axes]]]:
    '''
    This function is named what it is as it is a good way to visualise the 
    different error subsets and thus error splits after running different 
    error functions from 
    :py:func:`target_extraction.analysis.sentiment_error_analysis.error_analysis_wrapper`
    and further more if you are exploring them over different datasets. 
    To create a graph with these different error analysis subsets, Models, and datasets 
    the following column and row names may be useful: `df_column_name` = `Dataset`,
    `df_row_name` = `Error Split`, `df_x_name` = `Error Subset`, `df_y_name` 
    = `Accuracy (%)`, and `df_hue_name` = `Model`.

    :param metric_df: A DataFrame that contains the scores to plot along with 
                      the columns named by the `df_*` arguments.
    :param df_column_name: Name of the column in `metric_df` that will be used 
                           to determine the categorical variables to facet the 
                           column part of the returned figure
    :param df_row_name: Name of the column in `metric_df` that will be used 
                        to determine the categorical variables to facet the 
                        row part of the returned figure
    :param df_x_name: Name of the column in `metric_df` that will be used to 
                      represent the X-axis in the figure.
    :param df_y_name: Name of the column in `metric_df` that will be used to 
                      represent the Y-axis in the figure.
    :param df_hue_name: Name of the column in `metric_df` that will be used to 
                        represent the hue in the figure
    :param seaborn_plot_name: Name of the seaborn plotting function to use as 
                              the plots within the figure
    :param seaborn_kwargs: The key word arguments to give to the seaborn 
                           plotting function. The dict is copied, so the 
                           caller's dict is not modified.
    :param legend_column: Which column in the figure the legend should be 
                          associated too. The row the legend is associated 
                          with is fixed at row 0.
    :param figsize: Size of the figure, this is passed to the 
                    :py:func:`matplotlib.pyplot.subplots` as an argument.
    :param legend_bbox_to_anchor: Where the legend box should be within the 
                                  figure. This is passed as the `bbox_to_anchor`
                                  argument to 
                                  :py:func:`matplotlib.pyplot.Axes.legend`
    :param fontsize: Size of the font for the title, y-axis label, and 
                     x-axis label.
    :param legend_fontsize: Size of the font for the legend.
    :param tick_font_size: Size of the font on the y and x axis ticks.
    :param title_on_every_plot: Whether or not to have the title above every 
                                plot in the grid or just over the top row 
                                of plots.
    :param df_overall_metric: Name of the column in `metric_df` that stores 
                              the overall metric score for the entire dataset 
                              and not just the `subsets`.
    :param overall_seaborn_plot_name: Same as the `seaborn_plot_name` but for 
                                      plotting the overall metric. Required 
                                      when `df_overall_metric` is given.
    :param overall_seaborn_kwargs: Same as the `seaborn_kwargs` but for the 
                                   overall metric plot. None means no extra 
                                   keyword arguments.
    :param df_dataset_size: Name of the column in `metric_df` that stores 
                            the dataset size for one of the X-axis. If 
                            this is given it will create h_lines for each 
                            X-axis representing the dataset size
    :param dataset_h_line_offset: +/- offsets indicating the length of each 
                                  hline
    :param dataset_h_line_color: Color of the hline
    :param h_line_legend_name: Name to give to the h_line legend.
    :param h_line_legend_bbox_to_anchor: Where the h line legend box should be within the 
                                         figure. This is passed as the `bbox_to_anchor`
                                         argument to 
                                         :py:func:`matplotlib.pyplot.Axes.legend`
    :param dataset_y_label: The Y-Label for the right hand side Y-axis.
    :param gridspec_kw: :py:func:`matplotlib.pyplot.subplots` `gridspec_kw` argument
    :param row_order: A list of all unique `df_row_name` values in the order 
                      the rows should appear in.
    :param column_order: A list of all unique `df_column_name` values in the order 
                         the columns should appear in.
    :returns: A tuple of 1. The figure  2. The associated axes within the 
              figure, exactly as returned by 
              :py:func:`matplotlib.pyplot.subplots`. The figure will contain 
              N x M plots where N (rows) is the number of unique values in the 
              `metric_df` `df_row_name` column and M (columns) is the number 
              of unique values in the `metric_df` `df_column_name` column.
    :raises ValueError: If `df_overall_metric` is given without 
                        `overall_seaborn_plot_name`.
    '''
    def plot_error_split(df: pd.DataFrame, 
                         error_axs: List[matplotlib.axes.Axes], 
                         column_names: List[str],
                         first_row: bool, last_row: bool,
                         number_hue_values: int = 1,
                         h_line_legend_bbox_to_anchor: Optional[Tuple[float, float]] = None
                         ) -> None:
        '''Plot one row of the figure, one subplot per value in `column_names`.'''
        for col_index, column_name in enumerate(column_names):
            _df = df[df[df_column_name]==column_name]
            ax = error_axs[col_index]
            getattr(sns, seaborn_plot_name)(x=df_x_name, y=df_y_name, 
                                            hue=df_hue_name, data=_df, 
                                            ax=ax, **seaborn_kwargs)
            # Overlay the overall metric on the same axes: the overall column 
            # is renamed to the Y column so both plots share the Y-axis.
            if df_overall_metric:
                _temp_overall_df: pd.DataFrame = _df.copy(deep=True)
                _temp_overall_df = _temp_overall_df.drop(columns=df_y_name)
                _temp_overall_df = _temp_overall_df.rename(columns={df_overall_metric: df_y_name})
                getattr(sns, overall_seaborn_plot_name)(x=df_x_name, y=df_y_name, 
                                                        hue=df_hue_name, 
                                                        data=_temp_overall_df, 
                                                        ax=ax, 
                                                        **overall_seaborn_kwargs)
            
            # Y axis labelling
            row_name = _df[df_row_name].unique()
            row_name_err = (f'There should only be one unique row name {row_name} '
                            f'from the row column {df_row_name}')
            assert len(row_name) == 1, row_name_err
            row_name = row_name[0]
            if col_index != 0:
                ax.set_ylabel('')
            else:
                ax.set_ylabel(f'{df_row_name}={row_name}\n{df_y_name}', 
                              fontsize=fontsize)
            # X axis labelling
            if last_row:
                ax.set_xlabel(df_x_name, fontsize=fontsize)
            else:
                ax.set_xlabel('')
            # Title
            if first_row or title_on_every_plot:
                ax.set_title(f'{df_column_name}={column_name}', fontsize=fontsize)
            # Legend
            if col_index == legend_column and first_row:
                ax.legend(bbox_to_anchor=legend_bbox_to_anchor, 
                          loc='lower left', fontsize=legend_fontsize, 
                          ncol=number_hue_values, borderaxespad=0.)
            else:
                # get_legend() returns None when no legend was drawn.
                legend = ax.get_legend()
                if legend is not None:
                    legend.remove()
            
            if df_dataset_size:
                dataset_ax = ax.twinx()
                # only if it is the last column
                if col_index == (len(column_names) - 1):
                    dataset_ax.set_ylabel(dataset_y_label, fontsize=fontsize)
                dataset_sizes = []
                x_values = [x_tick_label.get_text() 
                            for x_tick_label in ax.get_xticklabels()]
                for x_value in x_values:
                    dataset_size = _df.loc[_df[df_x_name]==x_value][df_dataset_size].unique()
                    dataset_size = dataset_size.tolist()
                    assert 1 == len(dataset_size)
                    dataset_sizes.append(dataset_size)
    
                for index, dataset_size in enumerate(dataset_sizes):
                    x_indexes = (index - dataset_h_line_offset, 
                                 index + dataset_h_line_offset)
                    if index == 0 and col_index == legend_column and first_row:
                        dataset_ax.hlines(dataset_size, x_indexes[0], x_indexes[1], 
                                          linestyles='dashed', color=dataset_h_line_color,
                                          label=h_line_legend_name)
                        # Default: place the h line legend just above the main legend.
                        if h_line_legend_bbox_to_anchor is None:
                            h_line_legend_bbox_0, h_line_legend_bbox_1 = legend_bbox_to_anchor
                            h_line_legend_bbox_1 = h_line_legend_bbox_1 + 0.1
                            h_line_legend_bbox_to_anchor = (h_line_legend_bbox_0, h_line_legend_bbox_1)
                        dataset_ax.legend(bbox_to_anchor=h_line_legend_bbox_to_anchor, 
                                          loc='lower left', fontsize=legend_fontsize, 
                                          borderaxespad=0.)
                    else:
                        dataset_ax.hlines(dataset_size, x_indexes[0], x_indexes[1], 
                                          linestyles='dashed', 
                                          color=dataset_h_line_color)

    if df_overall_metric and overall_seaborn_plot_name is None:
        raise ValueError('`overall_seaborn_plot_name` must be given when '
                         '`df_overall_metric` is given')
    if overall_seaborn_kwargs is None:
        overall_seaborn_kwargs = {}

    plt.rc('xtick', labelsize=tick_font_size)
    plt.rc('ytick', labelsize=tick_font_size)
    # Seaborn plotting options. Copy any user supplied kwargs so that adding 
    # the palette below does not mutate the caller's dict.
    if seaborn_kwargs is None:
        if seaborn_plot_name == 'pointplot':
            seaborn_kwargs = {'join': False, 'ci': 'sd', 'dodge': 0.4, 
                              'capsize': 0.05}
        else:
            seaborn_kwargs = {}
    else:
        seaborn_kwargs = dict(seaborn_kwargs)
    # Ensure that all the values in hue column will always be the same colour 
    # across subplots; n_colors cycles the palette so every hue value gets a 
    # colour even when there are more values than palette colours.
    hue_values = metric_df[df_hue_name].unique().tolist()
    number_hue_values = len(hue_values)
    palette = dict(zip(hue_values, sns.color_palette(n_colors=number_hue_values)))
    seaborn_kwargs['palette'] = palette
    # Determine the number of rows
    row_names = metric_df[df_row_name].unique().tolist()
    num_rows = len(row_names)
    if row_order is not None:
        row_order_error = (f'The `row_order` argument {row_order} should contain '
                           'the same values as the unique values in the '
                           f'`df_row_name` {df_row_name} column which are {row_names}')
        assert set(row_order) == set(row_names), row_order_error
        row_names = row_order
    # Number of columns
    column_names = metric_df[df_column_name].unique().tolist()
    number_columns = len(column_names)
    if column_order is not None:
        column_order_error = (f'The `column_order` argument {column_order} '
                              'should contain the same values as the unique '
                              f'values in the `df_column_name` {df_column_name} '
                              f'column which are {column_names}')
        assert set(column_order) == set(column_names), column_order_error
        column_names = column_order
    if figsize is None:
        length = num_rows * 4
        width = number_columns * 5
        figsize = (width, length)
    if gridspec_kw is None:
        if df_dataset_size is not None:
            gridspec_kw = {'wspace': 0.3}
        else:
            gridspec_kw = {}
    fig, axs = plt.subplots(nrows=num_rows, ncols=number_columns, 
                            figsize=figsize, gridspec_kw=gridspec_kw)
    # plt.subplots squeezes singleton dimensions (a lone Axes or a 1-D array),
    # so build a rows x columns grid for indexing while still returning `axs` 
    # exactly as plt.subplots produced it.
    if num_rows == 1 and number_columns == 1:
        axs_grid = [[axs]]
    elif num_rows == 1:
        axs_grid = [list(axs)]
    elif number_columns == 1:
        axs_grid = [[single_ax] for single_ax in axs]
    else:
        axs_grid = [list(row_axs) for row_axs in axs]

    for row_index, row_name in enumerate(row_names):
        if num_rows > 1:
            row_metric_df = metric_df[metric_df[df_row_name]==row_name]
        else:
            row_metric_df = metric_df
        plot_error_split(row_metric_df, axs_grid[row_index], column_names,
                         first_row=(row_index == 0), 
                         last_row=(row_index == num_rows - 1),
                         number_hue_values=number_hue_values, 
                         h_line_legend_bbox_to_anchor=h_line_legend_bbox_to_anchor)
    return fig, axs
def create_subset_heatmap(subset_df: pd.DataFrame, value_column: str, 
                          pivot_table_agg_func: Optional[Callable[[pd.Series], Any]] = None,
                          font_label_size: int = 10,
                          cubehelix_palette_kwargs: Optional[Dict[str, Any]] = None,
                          value_range: Optional[List[int]] = None,
                          lines: bool = True, line_color: str = 'k',
                          vertical_lines_index: Optional[List[int]] = None,
                          horizontal_lines_index: Optional[List[int]] = None,
                          ax: Optional[matplotlib.pyplot.Axes] = None,
                          heatmap_kwargs: Optional[Dict[str, Any]] = None
                          ) -> matplotlib.pyplot.Axes:
    '''
    Plot `value_column` as a heatmap of Dataset (rows) against Error Subset 
    (columns) with a discrete colour bar that has one colour per value.

    :param subset_df: A DataFrame that contains the following columns: 
                      1. Error Split, 2. Error Subset, 3. Dataset, 
                      and 4. `value_column`
    :param value_column: The column that contains the value to be plotted in the 
                         heatmap.
    :param pivot_table_agg_func: A pivot table is created to build the heatmap. 
                                 It aggregates the `value_column` by Dataset and 
                                 combined Error Split/Error Subset. This argument 
                                 replaces the default aggregation function (the 
                                 mean) with a custom function.
    :param font_label_size: Font sizes of the labels on the returned plot. 
                            NOTE: the tick label size is set through `plt.rc`, 
                            which changes the global matplotlib rcParams and 
                            therefore also affects figures created afterwards.
    :param cubehelix_palette_kwargs: Keywords arguments to give to the 
                                     seaborn.cubehelix_palette
                                     https://seaborn.pydata.org/generated/seaborn.cubehelix_palette.html.
                                     Default produces white to dark red.
    :param value_range: The values shown on the discrete colour bar, one colour 
                        per value. It should cover at least all unique values 
                        in `value_column`. If None, the unique non-missing 
                        values of the pivoted `value_column` are used.
    :param lines: Whether or not lines should appear on the plot to define the 
                  different error splits.
    :param line_color: Color of the lines if the lines are to be displayed. The 
                       choice of color names can be found here: 
                       https://matplotlib.org/3.1.1/gallery/color/named_colors.html#sphx-glr-gallery-color-named-colors-py
    :param vertical_lines_index: The index of the lines in the vertical/column 
                                 direction. If None the default is 
                                 [0,3,7,11,15,18], which matches the group 
                                 boundaries (DS, TSSR, NT, n-shot, TSR) of the 
                                 full column order used by this function.
    :param horizontal_lines_index: The index of the lines in the 
                                   horizontal/row direction. If None the 
                                   default is [0,1,2,3] (presumably for three 
                                   datasets).
    :param ax: A matplotlib Axes to give to the seaborn function to plot the 
               heatmap on to. If None, seaborn plots on the current Axes.
    :param heatmap_kwargs: Keyword arguments to pass to the seaborn.heatmap 
                           function
    :returns: A heatmap where the Y-axis represents the datasets, X-axis 
              represents the Error subsets formatted when appropriate with the 
              Error split name, and the values come from the `value_column`. 
              Only Error subsets from the predefined column order (DS1 ... 
              TSR KSKT) are plotted; any other subsets are dropped. The 
              heatmap assumes the `value_column` contains discrete values as the 
              color bar is discrete rather than continuous. If you want a 
              continuous color bar it is recommended that you use Seaborn heatmap.
    :raises ValueError: If `value_range` is None and there are no non-missing 
                        values to plot.
    '''
    df_copy = subset_df.copy(deep=True)
    # Prefix every subset with its split name (e.g. `NT` + `Low` -> `NT Low`), 
    # except for the `DS` split whose subsets (DS1, DS2, ...) stand alone.
    error_split = df_copy['Error Split'].astype(str)
    error_split = error_split.where(df_copy['Error Split'] != 'DS', '')
    combined_subset = error_split + ' ' + df_copy['Error Subset'].astype(str)
    df_copy['Combined Error Subset'] = combined_subset.str.strip()

    if pivot_table_agg_func is None:
        # The string form is equivalent to np.mean but avoids the pandas 
        # FutureWarning raised when numpy reductions are passed as aggfunc.
        pivot_table_agg_func = 'mean'
    heatmap_df = pd.pivot_table(data=df_copy, values=value_column, 
                                columns=['Combined Error Subset'], 
                                index=['Dataset'], aggfunc=pivot_table_agg_func)

    full_column_order = ['DS1', 'DS2', 'DS3', 'TSSR 1', 'TSSR 1-Multi', 'TSSR High', 
                         'TSSR Low', 'NT 1', 'NT Low', 'NT Med', 'NT High', 
                         'n-shot Zero', 'n-shot Low', 'n-shot Med', 'n-shot High', 
                         'TSR USKT', 'TSR UT', 'TSR KSKT']
    # Keep the predefined order, skipping subsets absent from the data. 
    # Columns not in the predefined order are dropped by the reindex.
    existing_columns = set(heatmap_df.columns.tolist())
    column_order = [column for column in full_column_order 
                    if column in existing_columns]
    heatmap_df = heatmap_df.reindex(column_order, axis=1)

    plt.rc('xtick', labelsize=font_label_size)
    plt.rc('ytick', labelsize=font_label_size)

    if value_range is not None:
        unique_values = value_range
    else:
        cell_values = heatmap_df.values
        # Dataset/subset combinations with no rows are NaN in the pivot table; 
        # they are not a value and must not get their own colour/tick.
        unique_values = np.unique(cell_values[~pd.isnull(cell_values)])
    num_unique_values = len(unique_values)
    if num_unique_values == 0:
        raise ValueError('There are no values to plot: the pivoted '
                         f'`{value_column}` column contains no non-missing '
                         'values and no `value_range` was given.')
    vmin = min(unique_values)
    vmax = max(unique_values)
    # The colour bar spans [vmin, vmax] split into one equal band per value; 
    # each tick is placed at the centre of its band. This must be offset by 
    # vmin, otherwise ticks are misplaced whenever the minimum is not 0.
    color_bar_spacing = (vmax - vmin) / num_unique_values
    colorbar_values = [vmin + (index + 0.5) * color_bar_spacing 
                       for index in range(num_unique_values)]

    if cubehelix_palette_kwargs is None:
        cubehelix_palette_kwargs = {'hue': 1, 'gamma': 2.2, 'light': 1.0, 
                                    'dark': 0.7}
    palette = sns.cubehelix_palette(n_colors=num_unique_values, 
                                    **cubehelix_palette_kwargs)
    if heatmap_kwargs is None:
        heatmap_kwargs = {}
    base_heatmap_kwargs = {'linewidths': .5, 'linecolor': 'lightgray',
                           'cmap': matplotlib.colors.ListedColormap(palette),
                           'vmin': vmin, 'vmax': vmax}
    # Only forward `ax` when given so an `ax` inside heatmap_kwargs still works.
    if ax is not None:
        base_heatmap_kwargs['ax'] = ax
    ax = sns.heatmap(heatmap_df, **base_heatmap_kwargs, **heatmap_kwargs)

    # The colour bar is absent when e.g. `cbar=False` is in heatmap_kwargs.
    cb = ax.collections[-1].colorbar
    if cb is not None:
        cb.set_ticks(colorbar_values)
        cb.set_ticklabels(unique_values)
    ax.set_xlabel('Error Subset', fontsize=font_label_size)
    ax.set_ylabel('Dataset', fontsize=font_label_size)
    if lines:
        if vertical_lines_index is None:
            vertical_lines_index = [0, 3, 7, 11, 15, 18]
        ax.vlines(vertical_lines_index, *ax.get_ylim(), colors=line_color)
        if horizontal_lines_index is None:
            horizontal_lines_index = [0, 1, 2, 3]
        ax.hlines(horizontal_lines_index, *ax.get_xlim(), colors=line_color)
    return ax