import numpy as np
import itertools
from functools import partial
from scipy.spatial.distance import cdist
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score
from .rule import Rule, compact_premises
from lore_explainer.explanation import Explanation, MultilabelExplanation
from lore_explainer.decision_tree import learn_local_decision_tree
from lore_explainer.neighgen import RandomGenerator, GeneticGenerator, RandomGeneticGenerator, ClosestInstancesGenerator
from lore_explainer.neighgen import GeneticProbaGenerator, RandomGeneticProbaGenerator
from lore_explainer.rule import get_rule, get_counterfactual_rules
from lore_explainer.util import calculate_feature_values, neuclidean, multilabel2str, multi_dt_predict
def default_kernel(d, kernel_width):
    """Turn distances into similarity weights with an exponential kernel.

    Evaluates ``sqrt(exp(-(d ** 2) / kernel_width ** 2))`` element-wise, so a
    distance of 0 maps to a weight of 1 and larger distances map to smaller
    weights.

    Args:
        d: distance or array of distances.
        kernel_width: width of the kernel; larger values make the weights
            decay more slowly with distance.

    Returns:
        The kernel value(s), with the same shape as ``d``.
    """
    neg_sq_dist = -(d ** 2)
    width_sq = kernel_width ** 2
    return np.sqrt(np.exp(neg_sq_dist / width_sq))
# LOcal Rule-based Explanation Method
class LOREM(object):
    """LOcal Rule-based Explanation Method.

    Explains a single black-box prediction by (1) obtaining a neighborhood of
    the instance, (2) labelling it with the black box, (3) fitting one or more
    local decision trees on it and (4) extracting a factual rule plus
    counterfactual rules from the tree(s).
    """

    def __init__(self, K, bb_predict, feature_names, class_name, class_values, numeric_columns, features_map,
                 neigh_type='genetic', categorical_use_prob=True, continuous_fun_estimation=False,
                 size=1000, ocr=0.1, multi_label=False, one_vs_rest=False, filter_crules=True, init_ngb_fn=True,
                 kernel_width=None, kernel=None, random_state=None, verbose=False, **kwargs):
        """Store the configuration and, optionally, build the neighborhood generator.

        Args:
            K: reference dataset; passed to ``calculate_feature_values`` and
                (for ``neigh_type='closest'``) to the generator as ``rK``.
                ``K.shape[1]`` is used as the number of real features.
            bb_predict: black-box prediction callable applied to a neighborhood.
            feature_names: names of the features; its length sets the default
                kernel width.
            class_name: name of the target. In one-vs-rest multi-label mode it
                is indexed per label, so it must be a sequence there.
            class_values: possible class values, indexed by predicted label.
            numeric_columns: names of the features treated as numeric.
            features_map: mapping ``idx -> dict``; the values of each inner
                dict are mapped back to ``idx`` in ``features_map_inv``.
                Presumably original-feature index -> {category: encoded
                column index} -- TODO confirm against the encoder.
            neigh_type: one of ``'random'``, ``'genetic'``, ``'rndgen'``,
                ``'geneticp'``, ``'rndgenp'`` or ``'closest'``.
            categorical_use_prob, continuous_fun_estimation, size: forwarded
                to ``calculate_feature_values`` (and to the closest-instances
                generator).
            ocr: forwarded to the neighborhood generator.
            multi_label: whether the black box returns multi-label outputs.
            one_vs_rest: with ``multi_label``, learn one tree per label.
            filter_crules: when true, ``bb_predict`` is forwarded to
                ``get_counterfactual_rules``; otherwise ``None`` is forwarded.
            init_ngb_fn: when false, the neighborhood generator is not built
                and ``neighgen_fn`` is left unset.
            kernel_width: kernel width; defaults to
                ``sqrt(len(feature_names)) * 0.75``.
            kernel: callable ``kernel(d, kernel_width=...)``; defaults to
                ``default_kernel``.
            random_state: seed given to ``np.random.seed`` and to the genetic
                generators.
            verbose: print progress messages.
            **kwargs: generator-specific options (``alpha1``, ``alpha2``,
                ``metric``, ``ngen``, ``mutpb``, ``cxpb``, ``tournsize``,
                ``halloffame_ratio``, ``bb_predict_proba``, ``Kc``, ``k``,
                ``core_neigh_type``, ``alphaf``, ``alphal``,
                ``metric_features``, ``metric_labels``).
        """
        self.random_state = random_state
        self.bb_predict = bb_predict
        self.K = K
        self.class_name = class_name
        self.feature_names = feature_names
        self.class_values = class_values
        self.numeric_columns = numeric_columns
        self.features_map = features_map
        self.neigh_type = neigh_type
        self.multi_label = multi_label
        self.one_vs_rest = one_vs_rest
        self.filter_crules = self.bb_predict if filter_crules else None
        self.verbose = verbose

        # Invert features_map: every value of an inner dict points back to its outer key.
        self.features_map_inv = None
        if self.features_map:
            self.features_map_inv = dict()
            for idx, idx_dict in self.features_map.items():
                for v in idx_dict.values():
                    self.features_map_inv[v] = idx

        kernel_width = np.sqrt(len(self.feature_names)) * .75 if kernel_width is None else kernel_width
        self.kernel_width = float(kernel_width)

        kernel = default_kernel if kernel is None else kernel
        self.kernel = partial(kernel, kernel_width=kernel_width)

        # Seeds NumPy's global random state (module-level side effect).
        np.random.seed(self.random_state)

        if init_ngb_fn:
            self.__init_neighbor_fn(ocr, categorical_use_prob, continuous_fun_estimation, size, kwargs)

    def explain_instance(self, x, samples=1000, use_weights=True, metric=neuclidean):
        """Explain the black-box prediction for ``x``.

        Args:
            x: instance to explain.
            samples: either an integer (number of neighbors to generate with
                the configured generator) or a ready-made neighborhood.
            use_weights: weight the neighborhood by kernelized distance from
                its first row; when false, no sample weights are used.
            metric: distance metric passed to ``scipy.spatial.distance.cdist``.

        Returns:
            An ``Explanation``, or a ``MultilabelExplanation`` when both
            ``one_vs_rest`` and ``multi_label`` are set.

        Note:
            Row 0 of the neighborhood is used as the reference point
            (distances, ``bb_pred`` and ``dt_pred`` all refer to it).
            Presumably the generators place ``x`` first -- TODO confirm.
        """
        # Accept NumPy integer scalars too, otherwise e.g. np.int64(1000)
        # would be mistaken for a user-supplied neighborhood.
        if isinstance(samples, (int, np.integer)):
            if self.verbose:
                print('generating neighborhood - %s' % self.neigh_type)
            Z = self.neighgen_fn(x, samples)
        else:
            Z = samples

        Yb = self.bb_predict(Z)
        if self.multi_label:
            # Drop neighbors for which the black box predicts no label at all.
            Z = np.array([z for z, y in zip(Z, Yb) if np.sum(y) > 0])
            Yb = self.bb_predict(Z)

        if self.verbose:
            if not self.multi_label:
                neigh_class, neigh_counts = np.unique(Yb, return_counts=True)
                neigh_class_counts = {self.class_values[k]: v for k, v in zip(neigh_class, neigh_counts)}
            else:
                neigh_counts = np.sum(Yb, axis=0)
                neigh_class_counts = {self.class_values[k]: v for k, v in enumerate(neigh_counts)}

            print('synthetic neighborhood class counts %s' % neigh_class_counts)

        weights = None if not use_weights else self.__calculate_weights__(Z, metric)

        if self.one_vs_rest and self.multi_label:
            exp = self.__explain_tabular_instance_multiple_tree(x, Z, Yb, weights)
        else:  # binary, multiclass, multilabel all together
            exp = self.__explain_tabular_instance_single_tree(x, Z, Yb, weights)

        return exp

    def __calculate_weights__(self, Z, metric):
        """Return kernel weights for every row of ``Z`` based on its distance from ``Z[0]``."""
        z_min, z_max = np.min(Z), np.max(Z)
        # NOTE(review): min-max scaling is applied only when BOTH max != 1 and
        # min != 0; data such as [0, 100] is therefore left unscaled. This
        # looks like it may have been meant as `or` -- confirm before changing,
        # since it affects the resulting weights for scale-sensitive metrics.
        if z_max != 1 and z_min != 0:
            Zn = (Z - z_min) / (z_max - z_min)
        else:
            Zn = Z
        distances = cdist(Zn, Zn[0].reshape(1, -1), metric=metric).ravel()
        return self.kernel(distances)

    @staticmethod
    def _training_slice(Z, weights):
        """Return ``(idx_train, train_weights)`` for the leading 95% of ``Z``.

        ``train_weights`` is ``None`` when ``weights`` is ``None`` so that
        unweighted explanations do not fail on slicing.
        """
        idx_train = len(Z) - int(len(Z) * 0.05)
        train_weights = None if weights is None else weights[:idx_train]
        return idx_train, train_weights

    def __explain_tabular_instance_single_tree(self, x, Z, Yb, weights):
        """Fit one local decision tree on the neighborhood and extract the explanation.

        The tree is trained on the first 95% of the rows; fidelity is the
        (optionally weighted) score of the tree against the black-box labels
        on the whole neighborhood.
        """
        if self.verbose:
            print('learning local decision tree')

        idx_train, train_weights = self._training_slice(Z, weights)
        dt = learn_local_decision_tree(Z[:idx_train], Yb[:idx_train], train_weights, self.class_values,
                                       self.multi_label, self.one_vs_rest, prune_tree=False)
        Yc = dt.predict(Z)

        fidelity = dt.score(Z, Yb, sample_weight=weights)

        if self.verbose:
            print('retrieving explanation')

        rule = get_rule(x, dt, self.feature_names, self.class_name, self.class_values, self.numeric_columns,
                        self.multi_label)
        crules, deltas = get_counterfactual_rules(x, Yc[0], dt, Z, Yc, self.feature_names, self.class_name,
                                                  self.class_values, self.numeric_columns, self.features_map,
                                                  self.features_map_inv, self.filter_crules, self.multi_label)

        exp = Explanation()
        exp.bb_pred = Yb[0]
        exp.dt_pred = Yc[0]
        exp.rule = rule
        exp.crules = crules
        exp.deltas = deltas
        exp.dt = dt
        exp.fidelity = fidelity

        return exp

    def __explain_tabular_instance_multiple_tree(self, x, Z, Yb, weights):
        """Fit one local tree per label (one-vs-rest) and merge the per-label explanations.

        Labels that are constant over the neighborhood get a ``DummyClassifier``
        and an empty-premise rule instead of a learned tree.
        """
        dt_list = list()
        premises = list()
        rule_list = list()
        crules_list = list()
        deltas_list = list()

        nbr_labels = len(self.class_name)

        if self.verbose:
            print('learning %s local decision trees' % nbr_labels)

        idx_train, train_weights = self._training_slice(Z, weights)
        class_values = [0, 1]

        for label_idx in range(nbr_labels):
            nbr_positive = np.sum(Yb[:, label_idx])
            if nbr_positive == 0 or nbr_positive == len(Yb):
                # Label never / always predicted in the neighborhood: nothing to learn.
                outcome = 0 if nbr_positive == 0 else 1
                rule = Rule([], outcome, [0, 1])
                crules, deltas = list(), list()
                dt = DummyClassifier()
                dt.fit(np.zeros(Z.shape[1]).reshape(1, -1), np.array([outcome]))
            else:
                dt = learn_local_decision_tree(Z[:idx_train], Yb[:idx_train, label_idx], train_weights,
                                               self.class_values, self.multi_label, self.one_vs_rest,
                                               prune_tree=False)
                Yc = dt.predict(Z)
                rule = get_rule(x, dt, self.feature_names, self.class_name[label_idx], class_values,
                                self.numeric_columns, multi_label=False)
                crules, deltas = get_counterfactual_rules(x, Yc[0], dt, Z, Yc, self.feature_names,
                                                          self.class_name[label_idx], class_values,
                                                          self.numeric_columns, self.features_map,
                                                          self.features_map_inv, self.filter_crules,
                                                          multi_label=False)

            dt_list.append(dt)
            rule_list.append(rule)
            premises.extend(rule.premises)
            crules_list.append(crules)
            deltas_list.append(deltas)

        if self.verbose:
            print('retrieving explanation')

        Yc = multi_dt_predict(Z, dt_list)
        fidelity = accuracy_score(Yb, Yc, sample_weight=weights)

        premises = compact_premises(premises)
        dt_outcome = multi_dt_predict(x.reshape(1, -1), dt_list)[0]
        cons = multilabel2str(dt_outcome, self.class_values)
        rule = Rule(premises, cons, self.class_name)

        exp = MultilabelExplanation()
        exp.bb_pred = Yb[0]
        exp.dt_pred = Yc[0]
        exp.rule = rule
        exp.crules = list(itertools.chain.from_iterable(crules_list))
        exp.deltas = list(itertools.chain.from_iterable(deltas_list))
        exp.dt = dt_list
        exp.fidelity = fidelity

        exp.rule_list = rule_list
        exp.crules_list = crules_list
        exp.deltas_list = deltas_list

        return exp

    def __init_neighbor_fn(self, ocr, categorical_use_prob, continuous_fun_estimation, size, kwargs):
        """Build the neighborhood generator for ``self.neigh_type`` and store its ``generate`` method.

        Sets ``self.feature_values`` (``None`` for ``'closest'``) and
        ``self.neighgen_fn``.

        Args:
            ocr, categorical_use_prob, continuous_fun_estimation, size: see
                ``__init__``.
            kwargs: dict of generator-specific options (see ``__init__``).

        Raises:
            ValueError: if ``self.neigh_type`` is not a known generator name.
        """
        # The four genetic variants share one constructor signature; the
        # "proba" ones additionally take bb_predict_proba.
        genetic_generators = {
            'genetic': GeneticGenerator,
            'rndgen': RandomGeneticGenerator,
            'geneticp': GeneticProbaGenerator,
            'rndgenp': RandomGeneticProbaGenerator,
        }
        proba_generators = ('geneticp', 'rndgenp')

        numeric_columns_index = [i for i, c in enumerate(self.feature_names) if c in self.numeric_columns]

        self.feature_values = None
        if self.neigh_type == 'random' or self.neigh_type in genetic_generators:
            if self.verbose:
                print('calculating feature values')

            self.feature_values = calculate_feature_values(self.K, numeric_columns_index,
                                                           categorical_use_prob=categorical_use_prob,
                                                           continuous_fun_estimation=continuous_fun_estimation,
                                                           size=size)

        nbr_features = len(self.feature_names)
        nbr_real_features = self.K.shape[1]
        # Positional arguments common to every generator.
        base_args = (self.bb_predict, self.feature_values, self.features_map, nbr_features,
                     nbr_real_features, numeric_columns_index)

        if self.neigh_type in genetic_generators:
            gen_kwargs = dict(
                ocr=ocr,
                alpha1=kwargs.get('alpha1', 0.5),
                alpha2=kwargs.get('alpha2', 0.5),
                metric=kwargs.get('metric', neuclidean),
                ngen=kwargs.get('ngen', 10),
                mutpb=kwargs.get('mutpb', 0.5),
                cxpb=kwargs.get('cxpb', 0.7),
                tournsize=kwargs.get('tournsize', 3),
                halloffame_ratio=kwargs.get('halloffame_ratio', 0.1),
                random_seed=self.random_state,
                verbose=self.verbose,
            )
            if self.neigh_type in proba_generators:
                gen_kwargs['bb_predict_proba'] = kwargs.get('bb_predict_proba', None)
            neighgen = genetic_generators[self.neigh_type](*base_args, **gen_kwargs)

        elif self.neigh_type == 'random':
            neighgen = RandomGenerator(*base_args, ocr=ocr)

        elif self.neigh_type == 'closest':
            Kc = kwargs.get('Kc', None)
            k = kwargs.get('k', None)
            core_neigh_type = kwargs.get('core_neigh_type', 'simple')
            alphaf = kwargs.get('alphaf', 0.5)
            alphal = kwargs.get('alphal', 0.5)
            metric_features = kwargs.get('metric_features', neuclidean)
            metric_labels = kwargs.get('metric_labels', neuclidean)
            neighgen = ClosestInstancesGenerator(*base_args, ocr=ocr,
                                                 K=Kc, rK=self.K, k=k, core_neigh_type=core_neigh_type,
                                                 alphaf=alphaf, alphal=alphal,
                                                 metric_features=metric_features,
                                                 metric_labels=metric_labels,
                                                 categorical_use_prob=categorical_use_prob,
                                                 continuous_fun_estimation=continuous_fun_estimation,
                                                 size=size, verbose=self.verbose)
        else:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError('unknown neighborhood generator: %r' % (self.neigh_type,))

        self.neighgen_fn = neighgen.generate