Source code for lore_explainer.lorem

import numpy as np

import itertools
from functools import partial

from scipy.spatial.distance import cdist

from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score

from .rule import Rule, compact_premises

from lore_explainer.explanation import Explanation, MultilabelExplanation
from lore_explainer.decision_tree import learn_local_decision_tree
from lore_explainer.neighgen import RandomGenerator, GeneticGenerator, RandomGeneticGenerator, ClosestInstancesGenerator
from lore_explainer.neighgen import GeneticProbaGenerator, RandomGeneticProbaGenerator
from lore_explainer.rule import get_rule, get_counterfactual_rules
from lore_explainer.util import calculate_feature_values, neuclidean, multilabel2str, multi_dt_predict


def default_kernel(d, kernel_width):
    """Turn distances into kernel weights.

    Computes ``sqrt(exp(-d**2 / kernel_width**2))``.

    :param d: distance value(s); a scalar or a numpy array (the expression is
        applied element-wise by numpy).
    :param kernel_width: width of the kernel; larger values make the weight
        decay more slowly with distance.
    :return: weight(s) with the same shape as ``d``; a distance of 0 maps to 1.
    """
    scaled_sq_distance = (d ** 2) / kernel_width ** 2
    return np.sqrt(np.exp(-scaled_sq_distance))
# LOcal Rule-based Explanation Method
class LOREM(object):
    """LOcal Rule-based Explanation Method.

    Explains a single prediction of a black box by (1) building a synthetic
    neighborhood ``Z`` around the instance, (2) labelling it with the black
    box, (3) fitting one local decision tree (or one tree per label when
    ``one_vs_rest`` and ``multi_label`` are both set) and (4) reading a
    factual rule and counterfactual rules off the tree(s).

    :param K: dataset handed to ``calculate_feature_values`` and, for the
        ``'closest'`` generator, passed as ``rK``; ``K.shape[1]`` is read as
        the number of real features.
    :param bb_predict: callable applied to the neighborhood ``Z`` to obtain
        the black-box labels.
    :param feature_names: sequence of feature names; its length also sets the
        default kernel width.
    :param class_name: name of the target; in the one-vs-rest multi-label
        path it is indexed per label and its length is the number of labels.
    :param class_values: class values, indexed by predicted class/label.
    :param numeric_columns: names of the numeric features.
    :param features_map: mapping ``idx -> {key: value}``; when truthy an
        inverse mapping ``value -> idx`` is stored in ``features_map_inv``.
    :param neigh_type: one of ``'random'``, ``'genetic'``, ``'rndgen'``,
        ``'geneticp'``, ``'rndgenp'``, ``'closest'``.
    :param categorical_use_prob: forwarded to ``calculate_feature_values``
        and to the ``'closest'`` generator.
    :param continuous_fun_estimation: forwarded to
        ``calculate_feature_values`` and to the ``'closest'`` generator.
    :param size: forwarded to ``calculate_feature_values`` and to the
        ``'closest'`` generator.
    :param ocr: forwarded to every neighborhood generator.
    :param multi_label: whether the black box returns a label matrix.
    :param one_vs_rest: together with ``multi_label`` selects one tree per
        label instead of a single tree.
    :param filter_crules: when true, ``bb_predict`` is handed to
        ``get_counterfactual_rules``; otherwise ``None`` is.
    :param init_ngb_fn: when true the neighborhood generator is built in the
        constructor and exposed as ``self.neighgen_fn``.
    :param kernel_width: kernel width; defaults to
        ``sqrt(len(feature_names)) * 0.75``.
    :param kernel: callable ``kernel(d, kernel_width=...)``; defaults to
        ``default_kernel``.
    :param random_state: seed given to ``np.random.seed`` (this seeds numpy's
        global generator) and to the genetic generators.
    :param verbose: print progress messages.
    :param kwargs: generator-specific options read with ``kwargs.get`` when
        the neighborhood generator is built.
    """

    # Fraction of the neighborhood (taken from its tail) that is left out when
    # fitting the local tree(s); fidelity is still measured on all of Z.
    _HOLDOUT_FRACTION = 0.05

    def __init__(self, K, bb_predict, feature_names, class_name, class_values, numeric_columns, features_map,
                 neigh_type='genetic', categorical_use_prob=True, continuous_fun_estimation=False, size=1000,
                 ocr=0.1, multi_label=False, one_vs_rest=False, filter_crules=True, init_ngb_fn=True,
                 kernel_width=None, kernel=None, random_state=None, verbose=False, **kwargs):
        self.random_state = random_state
        self.bb_predict = bb_predict
        self.K = K
        self.class_name = class_name
        self.feature_names = feature_names
        self.class_values = class_values
        self.numeric_columns = numeric_columns
        self.features_map = features_map
        self.neigh_type = neigh_type
        self.multi_label = multi_label
        self.one_vs_rest = one_vs_rest
        # The black box itself is what get_counterfactual_rules receives as filter.
        self.filter_crules = self.bb_predict if filter_crules else None
        self.verbose = verbose

        # Invert features_map: every inner value points back to its outer index.
        self.features_map_inv = None
        if self.features_map:
            self.features_map_inv = dict()
            for idx, idx_dict in self.features_map.items():
                for v in idx_dict.values():
                    self.features_map_inv[v] = idx

        kernel_width = np.sqrt(len(self.feature_names)) * .75 if kernel_width is None else kernel_width
        self.kernel_width = float(kernel_width)

        kernel = default_kernel if kernel is None else kernel
        self.kernel = partial(kernel, kernel_width=kernel_width)

        np.random.seed(self.random_state)

        if init_ngb_fn:
            self.__init_neighbor_fn(ocr, categorical_use_prob, continuous_fun_estimation, size, kwargs)

    def explain_instance(self, x, samples=1000, use_weights=True, metric=neuclidean):
        """Build the explanation for instance ``x``.

        :param x: the instance to explain.
        :param samples: either an integer (Python or numpy), in which case a
            neighborhood of that size is generated with ``self.neighgen_fn``,
            or an already built neighborhood that is used as ``Z`` directly.
        :param use_weights: when true, neighborhood rows are weighted by the
            kernel of their distance from ``Z[0]``; when false no weights are
            used.
        :param metric: distance metric forwarded to ``scipy``'s ``cdist``.
        :return: an ``Explanation`` or, in the one-vs-rest multi-label case,
            a ``MultilabelExplanation``.
        """
        # np.integer is accepted too, so that e.g. an np.int64 count is not
        # mistaken for a ready-made neighborhood.
        if isinstance(samples, (int, np.integer)):
            if self.verbose:
                print('generating neighborhood - %s' % self.neigh_type)
            Z = self.neighgen_fn(x, samples)
        else:
            Z = samples

        Yb = self.bb_predict(Z)
        if self.multi_label:
            # Drop rows for which the black box predicts no label at all.
            # NOTE(review): Z[0] is used below as the reference instance; if it
            # is dropped here the reference silently changes — confirm.
            Z = np.array([z for z, y in zip(Z, Yb) if np.sum(y) > 0])
            Yb = self.bb_predict(Z)

        if self.verbose:
            if not self.multi_label:
                neigh_class, neigh_counts = np.unique(Yb, return_counts=True)
                neigh_class_counts = {self.class_values[k]: v for k, v in zip(neigh_class, neigh_counts)}
            else:
                neigh_counts = np.sum(Yb, axis=0)
                neigh_class_counts = {self.class_values[k]: v for k, v in enumerate(neigh_counts)}

            print('synthetic neighborhood class counts %s' % neigh_class_counts)

        weights = self.__calculate_weights__(Z, metric) if use_weights else None

        if self.one_vs_rest and self.multi_label:
            exp = self.__explain_tabular_instance_multiple_tree(x, Z, Yb, weights)
        else:  # binary, multiclass, multilabel all together
            exp = self.__explain_tabular_instance_single_tree(x, Z, Yb, weights)

        return exp

    def __calculate_weights__(self, Z, metric):
        """Return one kernel weight per row of ``Z``.

        Distances are measured from every row to ``Z[0]`` with ``metric`` and
        then passed through ``self.kernel``.

        :param Z: the neighborhood; ``Z[0]`` is the reference row.
        :param metric: metric forwarded to ``cdist``.
        :return: 1-d array of weights, one per row of ``Z``.
        """
        z_min, z_max = np.min(Z), np.max(Z)
        # NOTE(review): the `and` means that data with min == 0 but max != 1
        # (or max == 1 but min != 0) is NOT rescaled; `or` may have been
        # intended. Left unchanged because it alters the resulting weights.
        # The extra `z_max != z_min` guard avoids a 0/0 division (NaN weights)
        # on a constant neighborhood.
        if z_max != 1 and z_min != 0 and z_max != z_min:
            reference = (Z - z_min) / (z_max - z_min)
        else:
            reference = Z
        distances = cdist(reference, reference[0].reshape(1, -1), metric=metric).ravel()
        return self.kernel(distances)

    def __explain_tabular_instance_single_tree(self, x, Z, Yb, weights):
        """Explain ``x`` with a single local decision tree.

        :param x: the instance to explain.
        :param Z: the neighborhood.
        :param Yb: black-box labels of ``Z``.
        :param weights: per-row weights, or ``None`` for an unweighted fit.
        :return: a populated ``Explanation``.
        """
        if self.verbose:
            print('learning local decision tree')

        idx_train = len(Z) - int(len(Z) * self._HOLDOUT_FRACTION)
        # weights is None when the caller asked for use_weights=False;
        # slicing None would raise TypeError.
        train_weights = None if weights is None else weights[:idx_train]
        dt = learn_local_decision_tree(Z[:idx_train], Yb[:idx_train], train_weights, self.class_values,
                                       self.multi_label, self.one_vs_rest, prune_tree=False)
        Yc = dt.predict(Z)

        # Fidelity is scored on the whole neighborhood, held-out tail included.
        fidelity = dt.score(Z, Yb, sample_weight=weights)

        if self.verbose:
            print('retrieving explanation')

        rule = get_rule(x, dt, self.feature_names, self.class_name, self.class_values, self.numeric_columns,
                        self.multi_label)
        crules, deltas = get_counterfactual_rules(x, Yc[0], dt, Z, Yc, self.feature_names, self.class_name,
                                                  self.class_values, self.numeric_columns, self.features_map,
                                                  self.features_map_inv, self.filter_crules, self.multi_label)

        exp = Explanation()
        exp.bb_pred = Yb[0]
        exp.dt_pred = Yc[0]
        exp.rule = rule
        exp.crules = crules
        exp.deltas = deltas
        exp.dt = dt
        exp.fidelity = fidelity

        return exp

    def __explain_tabular_instance_multiple_tree(self, x, Z, Yb, weights):
        """Explain ``x`` with one local decision tree per label (one-vs-rest).

        :param x: the instance to explain.
        :param Z: the neighborhood.
        :param Yb: black-box label matrix of ``Z`` (one column per label).
        :param weights: per-row weights, or ``None`` for an unweighted fit.
        :return: a populated ``MultilabelExplanation``.
        """
        dt_list = list()
        premises = list()
        rule_list = list()
        crules_list = list()
        deltas_list = list()
        nbr_labels = len(self.class_name)

        if self.verbose:
            print('learning %s local decision trees' % nbr_labels)

        # Same train/held-out split for every label.
        idx_train = len(Z) - int(len(Z) * self._HOLDOUT_FRACTION)
        # weights is None when the caller asked for use_weights=False;
        # slicing None would raise TypeError.
        train_weights = None if weights is None else weights[:idx_train]
        binary_class_values = [0, 1]

        for label_idx in range(nbr_labels):
            nbr_positive = np.sum(Yb[:, label_idx])
            if nbr_positive == 0 or nbr_positive == len(Yb):
                # The label is constant over the neighborhood: no tree can be
                # learned, so use an empty rule and a constant classifier.
                outcome = 0 if nbr_positive == 0 else 1
                rule = Rule([], outcome, [0, 1])
                crules, deltas = list(), list()
                dt = DummyClassifier()
                dt.fit(np.zeros(Z.shape[1]).reshape(1, -1), np.array([outcome]))
            else:
                dt = learn_local_decision_tree(Z[:idx_train], Yb[:idx_train, label_idx], train_weights,
                                               self.class_values, self.multi_label, self.one_vs_rest,
                                               prune_tree=False)
                Yc = dt.predict(Z)
                rule = get_rule(x, dt, self.feature_names, self.class_name[label_idx], binary_class_values,
                                self.numeric_columns, multi_label=False)
                crules, deltas = get_counterfactual_rules(x, Yc[0], dt, Z, Yc, self.feature_names,
                                                          self.class_name[label_idx], binary_class_values,
                                                          self.numeric_columns, self.features_map,
                                                          self.features_map_inv, self.filter_crules,
                                                          multi_label=False)

            dt_list.append(dt)
            rule_list.append(rule)
            premises.extend(rule.premises)
            crules_list.append(crules)
            deltas_list.append(deltas)

        if self.verbose:
            print('retrieving explanation')

        Yc = multi_dt_predict(Z, dt_list)
        fidelity = accuracy_score(Yb, Yc, sample_weight=weights)

        # Merge the premises of all per-label rules into one overall rule.
        premises = compact_premises(premises)
        dt_outcome = multi_dt_predict(x.reshape(1, -1), dt_list)[0]
        cons = multilabel2str(dt_outcome, self.class_values)
        rule = Rule(premises, cons, self.class_name)

        exp = MultilabelExplanation()
        exp.bb_pred = Yb[0]
        exp.dt_pred = Yc[0]
        exp.rule = rule
        exp.crules = list(itertools.chain.from_iterable(crules_list))
        exp.deltas = list(itertools.chain.from_iterable(deltas_list))
        exp.dt = dt_list
        exp.fidelity = fidelity

        exp.rule_list = rule_list
        exp.crules_list = crules_list
        exp.deltas_list = deltas_list

        return exp

    def __init_neighbor_fn(self, ocr, categorical_use_prob, continuous_fun_estimation, size, kwargs):
        """Build the neighborhood generator selected by ``self.neigh_type``.

        Sets ``self.feature_values`` (``None`` for the ``'closest'``
        generator) and ``self.neighgen_fn`` (the generator's ``generate``
        method).

        :param ocr: forwarded to the generator.
        :param categorical_use_prob: forwarded to ``calculate_feature_values``
            / the ``'closest'`` generator.
        :param continuous_fun_estimation: forwarded likewise.
        :param size: forwarded likewise.
        :param kwargs: dict of generator-specific options; missing keys fall
            back to the defaults given below.
        :raises ValueError: if ``self.neigh_type`` is not a known generator.
        """
        genetic_generators = {
            'genetic': GeneticGenerator,
            'rndgen': RandomGeneticGenerator,
            'geneticp': GeneticProbaGenerator,
            'rndgenp': RandomGeneticProbaGenerator,
        }

        numeric_columns_index = [i for i, c in enumerate(self.feature_names) if c in self.numeric_columns]

        self.feature_values = None
        if self.neigh_type == 'random' or self.neigh_type in genetic_generators:
            if self.verbose:
                print('calculating feature values')

            self.feature_values = calculate_feature_values(self.K, numeric_columns_index,
                                                           categorical_use_prob=categorical_use_prob,
                                                           continuous_fun_estimation=continuous_fun_estimation,
                                                           size=size)

        nbr_features = len(self.feature_names)
        nbr_real_features = self.K.shape[1]

        # Positional arguments shared by every generator.
        common_args = (self.bb_predict, self.feature_values, self.features_map, nbr_features,
                       nbr_real_features, numeric_columns_index)

        if self.neigh_type in genetic_generators:
            genetic_kwargs = dict(
                ocr=ocr,
                alpha1=kwargs.get('alpha1', 0.5),
                alpha2=kwargs.get('alpha2', 0.5),
                metric=kwargs.get('metric', neuclidean),
                ngen=kwargs.get('ngen', 10),
                mutpb=kwargs.get('mutpb', 0.5),
                cxpb=kwargs.get('cxpb', 0.7),
                tournsize=kwargs.get('tournsize', 3),
                halloffame_ratio=kwargs.get('halloffame_ratio', 0.1),
                random_seed=self.random_state,
                verbose=self.verbose,
            )
            # Only the probability-based variants take the extra callable.
            if self.neigh_type in ('geneticp', 'rndgenp'):
                genetic_kwargs['bb_predict_proba'] = kwargs.get('bb_predict_proba', None)

            neighgen = genetic_generators[self.neigh_type](*common_args, **genetic_kwargs)

        elif self.neigh_type == 'random':
            neighgen = RandomGenerator(*common_args, ocr=ocr)

        elif self.neigh_type == 'closest':
            neighgen = ClosestInstancesGenerator(*common_args, ocr=ocr,
                                                 K=kwargs.get('Kc', None), rK=self.K,
                                                 k=kwargs.get('k', None),
                                                 core_neigh_type=kwargs.get('core_neigh_type', 'simple'),
                                                 alphaf=kwargs.get('alphaf', 0.5),
                                                 alphal=kwargs.get('alphal', 0.5),
                                                 metric_features=kwargs.get('metric_features', neuclidean),
                                                 metric_labels=kwargs.get('metric_labels', neuclidean),
                                                 categorical_use_prob=categorical_use_prob,
                                                 continuous_fun_estimation=continuous_fun_estimation,
                                                 size=size, verbose=self.verbose)

        else:
            # ValueError is an Exception subclass, so existing
            # `except Exception` handlers keep working.
            raise ValueError('unknown neighborhood generator: %r' % (self.neigh_type,))

        self.neighgen_fn = neighgen.generate