import copy
import json
import numpy as np
from .util import vector2dict, multilabel2str
from collections import defaultdict
class Condition(object):
    """A single threshold test ``att op thr`` on one attribute.

    ``op`` is ``'<='`` or ``'>'`` (the only operators produced by the rule
    extraction functions of this module).  For non-continuous attributes
    ``att`` is expected to have the form ``'name=value'`` (presumably a
    one-hot column) and the condition is rendered as ``name = value`` when
    ``op`` is ``'>'`` and ``name != value`` otherwise.
    """

    def __init__(self, att, op, thr, is_continuous=True):
        self.att = att
        self.op = op
        self.thr = thr
        self.is_continuous = is_continuous

    def __str__(self):
        if self.is_continuous:
            return '%s %s %.2f' % (self.att, self.op, self.thr)
        # partition() splits on the first '=' only, so a value that itself
        # contains '=' is kept intact instead of being truncated.
        name, _, val = self.att.partition('=')
        sign = '=' if self.op == '>' else '!='
        return '%s %s %s' % (name, sign, val)

    def __repr__(self):
        return 'Condition(%r, %r, %r, is_continuous=%r)' % (self.att, self.op, self.thr, self.is_continuous)

    def __eq__(self, other):
        # Let Python fall back to identity comparison for foreign types
        # instead of failing on a missing ``att`` attribute.
        if not isinstance(other, Condition):
            return NotImplemented
        return self.att == other.att and self.op == other.op and self.thr == other.thr

    def __hash__(self):
        # Hash exactly the fields __eq__ compares: str(self) also depends on
        # is_continuous and rounds thr, so equal conditions could hash apart.
        return hash((self.att, self.op, self.thr))
class Rule(object):
    """A conjunctive decision rule ``premises --> cons``.

    Attributes
    ----------
    premises : list
        Conditions that must all hold for the rule to apply.
    cons : object
        The consequence (predicted outcome) of the rule.
    class_name : str or list
        Name of the target.  When it is a list the consequence is rendered
        without the name.
    """

    def __init__(self, premises, cons, class_name):
        self.premises = premises
        self.cons = cons
        self.class_name = class_name

    def _pstr(self):
        return '{ %s }' % (', '.join([str(p) for p in self.premises]))

    def _cstr(self):
        if not isinstance(self.class_name, list):
            return '{ %s: %s }' % (self.class_name, self.cons)
        else:
            return '{ %s }' % self.cons

    def __str__(self):
        return '%s --> %s' % (self._pstr(), self._cstr())

    def __eq__(self, other):
        # Foreign types fall back to identity comparison instead of raising.
        if not isinstance(other, Rule):
            return NotImplemented
        return self.premises == other.premises and self.cons == other.cons

    def __len__(self):
        return len(self.premises)

    def __hash__(self):
        # Hash only what __eq__ compares.  str(self) also contains class_name,
        # which would give equal rules different hashes.
        return hash((tuple(self.premises), self.cons))

    def is_covered(self, x, feature_names):
        """Return True when instance ``x`` satisfies every premise of the rule.

        ``x`` is converted with ``vector2dict(x, feature_names)`` and each
        premise's ``att`` is looked up in the result.
        """
        xd = vector2dict(x, feature_names)
        for p in self.premises:
            if p.op == '<=' and xd[p.att] > p.thr:
                return False
            elif p.op == '>' and xd[p.att] <= p.thr:
                return False
        return True
def json2cond(obj):
    """Rebuild a Condition from a dict shaped like ConditionEncoder's output.

    ``obj`` must provide the keys 'att', 'op', 'thr' and 'is_continuous';
    a missing key raises KeyError.
    """
    att, op, thr, is_continuous = (obj[key] for key in ('att', 'op', 'thr', 'is_continuous'))
    return Condition(att, op, thr, is_continuous)
def json2rule(obj):
    """Rebuild a Rule from a dict shaped like RuleEncoder's output.

    Each entry of ``obj['premise']`` is decoded with json2cond; 'cons' and
    'class_name' are passed through unchanged.
    """
    return Rule(
        [json2cond(premise) for premise in obj['premise']],
        obj['cons'],
        obj['class_name'],
    )
class NumpyEncoder(json.JSONEncoder):
    """ Special json encoder for numpy types """

    def default(self, obj):
        """Convert numpy scalars and arrays to their native Python equivalents.

        Anything else is delegated to ``json.JSONEncoder.default``, which
        raises TypeError.
        """
        # The abstract scalar bases cover every sized int/uint/float type and
        # avoid aliases such as ``np.float_`` that NumPy 2.0 removed.
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)
class ConditionEncoder(json.JSONEncoder):
    """ Special json encoder for Condition types """

    def default(self, obj):
        """Serialize a Condition as a dict of its four fields.

        Non-Condition objects are delegated to ``json.JSONEncoder.default``.
        """
        if not isinstance(obj, Condition):
            return json.JSONEncoder.default(self, obj)
        return dict(
            att=obj.att,
            op=obj.op,
            thr=obj.thr,
            is_continuous=obj.is_continuous,
        )
class RuleEncoder(json.JSONEncoder):
    """ Special json encoder for Rule types """

    def default(self, obj):
        """Serialize a Rule as a dict with 'premise', 'cons' and 'class_name'.

        Premises are converted with ConditionEncoder.default; non-Rule objects
        are delegated to ``json.JSONEncoder.default``.
        """
        if not isinstance(obj, Rule):
            return json.JSONEncoder.default(self, obj)
        encode_condition = ConditionEncoder().default
        return {
            'premise': list(map(encode_condition, obj.premises)),
            'cons': obj.cons,
            'class_name': obj.class_name,
        }
def get_rule(x, dt, feature_names, class_name, class_values, numeric_columns, multi_label=False):
    """Extract the rule that the fitted decision tree ``dt`` applies to ``x``.

    Parameters
    ----------
    x : numpy array
        One instance; reshaped to a single row before querying ``dt``.
    dt : fitted scikit-learn decision tree
        Its ``tree_``, ``decision_path`` and ``predict`` are used.
    feature_names : sequence
        Attribute name for each tree feature index.
    class_name : str or list
        Stored on the returned Rule.
    class_values : sequence
        Indexed by the integer prediction when ``multi_label`` is False.
    numeric_columns : container
        Attribute names treated as continuous.
    multi_label : bool
        If True the prediction is converted with ``multilabel2str``.

    Returns
    -------
    Rule
        Premises along the root-to-leaf path of ``x`` (merged by
        compact_premises) and the tree's prediction as consequence.
    """
    x = x.reshape(1, -1)
    feature = dt.tree_.feature
    threshold = dt.tree_.threshold
    children_left = dt.tree_.children_left
    # Node ids visited by x, in root-to-leaf order (the leaf is last).
    path = dt.decision_path(x).indices

    premises = list()
    for node_id, child_id in zip(path[:-1], path[1:]):
        # Read the branch from the path the tree actually took instead of
        # re-comparing x with the threshold: scikit-learn casts inputs to
        # float32 before splitting, so a float64 re-comparison can disagree
        # for values right at the threshold.
        op = '<=' if child_id == children_left[node_id] else '>'
        att = feature_names[feature[node_id]]
        thr = threshold[node_id]
        iscont = att in numeric_columns
        premises.append(Condition(att, op, thr, iscont))

    dt_outcome = dt.predict(x)[0]
    cons = class_values[int(dt_outcome)] if not multi_label else multilabel2str(dt_outcome, class_values)
    premises = compact_premises(premises)
    return Rule(premises, cons, class_name)
def get_depth(dt):
    """Return the depth of the fitted tree ``dt`` (a lone root has depth 0).

    The tree is walked iteratively from node 0 using ``dt.tree_``'s
    ``children_left``/``children_right`` arrays.  A node whose two child
    entries are equal is treated as a leaf.
    """
    tree = dt.tree_
    left = tree.children_left
    right = tree.children_right
    depths = np.zeros(shape=tree.node_count, dtype=np.int64)
    pending = [(0, 0)]  # (node id, depth of that node)
    while pending:
        node, depth = pending.pop()
        depths[node] = depth
        if left[node] != right[node]:
            pending.append((left[node], depth + 1))
            pending.append((right[node], depth + 1))
    return np.max(depths)
def _leaf_paths(tree):
    """Return every root-to-leaf path of a fitted scikit-learn ``tree_``, ordered by leaf id."""
    children_left = tree.children_left
    children_right = tree.children_right
    parent_of = dict()
    is_leaf = np.zeros(shape=tree.node_count, dtype=bool)
    stack = [0]
    while stack:
        node_id = stack.pop()
        left, right = children_left[node_id], children_right[node_id]
        if left != right:  # test (split) node
            parent_of[left] = node_id
            parent_of[right] = node_id
            stack.extend((left, right))
        else:
            is_leaf[node_id] = True

    paths = list()
    for leaf_id in range(tree.node_count):
        if not is_leaf[leaf_id]:
            continue
        # Walk up to the root (the only node without a parent), then flip.
        path = [leaf_id]
        while path[-1] in parent_of:
            path.append(parent_of[path[-1]])
        path.reverse()
        paths.append(path)
    return paths


def get_rules(dt, feature_names, class_name, class_values, numeric_columns, multi_label=False):
    """Return one Rule per leaf of the fitted decision tree ``dt``.

    Each rule's premises are the split conditions from the root to the leaf
    (merged by compact_premises).  The consequence is the leaf's majority
    class, mapped through ``class_values``, or converted with
    ``multilabel2str`` when ``multi_label`` is True.  A tree consisting of
    a single node yields one rule with no premises, whose consequence is the
    tree's prediction for an all-zeros instance.

    Parameters are as in get_rule; ``numeric_columns`` decides each
    condition's ``is_continuous`` flag.
    """
    tree = dt.tree_
    feature = tree.feature
    threshold = tree.threshold
    children_left = tree.children_left
    value = tree.value

    if len(value) <= 1:
        x = np.zeros(len(feature_names)).reshape(1, -1)
        dt_outcome = dt.predict(x)[0]
        cons = class_values[int(dt_outcome)] if not multi_label else multilabel2str(dt_outcome, class_values)
        return [Rule([], cons, class_name)]

    # scikit-learn documents tree_.value as (node_count, n_outputs, max_n_classes).
    # Taking argmax over the last axis works for binary, multi-class and
    # multi-output trees, where a fixed reshape to (n, 2) only fit the binary case.
    leaf_outcome = np.argmax(value, axis=-1)
    if leaf_outcome.shape[1] == 1:
        leaf_outcome = leaf_outcome[:, 0]
    # NOTE(review): for multi-output trees each entry is the index of the
    # winning class per output; this equals the 0/1 label only if every
    # output's classes are [0, 1] -- confirm against multilabel2str's input.

    rules = list()
    for path in _leaf_paths(tree):
        premises = list()
        for node_id, child_id in zip(path[:-1], path[1:]):
            op = '<=' if child_id == children_left[node_id] else '>'
            att = feature_names[feature[node_id]]
            thr = threshold[node_id]
            iscont = att in numeric_columns
            premises.append(Condition(att, op, thr, iscont))
        outcome = leaf_outcome[path[-1]]
        cons = class_values[int(outcome)] if not multi_label else multilabel2str(outcome, class_values)
        premises = compact_premises(premises)
        rules.append(Rule(premises, cons, class_name))
    return rules
def compact_premises(plist):
    """Merge premises that test the same attribute into at most two conditions.

    For an attribute tested more than once, all ``'<='`` premises collapse
    into one with the smallest threshold and all ``'>'`` premises into one
    with the largest threshold.  Attributes tested once are kept as-is.
    Attributes appear in order of first occurrence in ``plist``.
    """
    att_list = defaultdict(list)
    for p in plist:
        att_list[p.att].append(p)

    compact_plist = list()
    for att, alist in att_list.items():
        if len(alist) == 1:
            compact_plist.append(alist[0])
            continue
        # Compare with None, not truthiness: a threshold of 0.0 is a valid
        # bound and must be neither overwritten nor dropped.
        max_thr = None  # tightest upper bound from '<=' premises
        min_thr = None  # tightest lower bound from '>' premises
        for av in alist:
            if av.op == '<=':
                max_thr = av.thr if max_thr is None else min(av.thr, max_thr)
            elif av.op == '>':
                min_thr = av.thr if min_thr is None else max(av.thr, min_thr)
        # Keep the attribute's continuity flag rather than resetting it to the
        # Condition default.
        is_continuous = alist[0].is_continuous
        if max_thr is not None:
            compact_plist.append(Condition(att, '<=', max_thr, is_continuous))
        if min_thr is not None:
            compact_plist.append(Condition(att, '>', min_thr, is_continuous))
    return compact_plist
def get_counterfactual_rules(x, y, dt, Z, Y, feature_names, class_name, class_values, numeric_columns, features_map,
                             features_map_inv, bb_predict=None, multi_label=False):
    """Find the rules of ``dt`` that flip ``x``'s outcome with the fewest changes.

    For every row of ``Z`` whose label in ``Y`` differs from ``y``, the rule
    that ``dt`` applies to that row is extracted (get_rule).  Its premises
    that ``x`` violates form ``delta`` (get_falsified_conditions).  Only the
    rules with the smallest ``len(delta)`` are kept, and rules with an
    identical ``delta`` are deduplicated.

    If ``bb_predict`` is given, a candidate is kept only when ``bb_predict``,
    applied to ``x`` modified to satisfy ``delta`` (apply_counterfactual),
    agrees with the rule's consequence.

    Returns
    -------
    (crule_list, delta_list)
        Parallel lists of Rules and, for each, the premises ``x`` falsifies.
    """
    clen = np.inf
    crule_list = list()
    delta_list = list()
    Z1 = Z[np.where(Y != y)[0]]
    xd = vector2dict(x, feature_names)
    for z in Z1:
        crule = get_rule(z, dt, feature_names, class_name, class_values, numeric_columns, multi_label)
        delta, qlen = get_falsified_conditions(xd, crule)
        if qlen > clen:
            # Cannot beat or tie the current best; skip the black-box query.
            continue

        if bb_predict is not None:
            xc = apply_counterfactual(x, delta, feature_names, features_map, features_map_inv, numeric_columns)
            bb_outcomec = bb_predict(xc.reshape(1, -1))[0]
            # Decode exactly like get_rule builds crule.cons so the two are comparable.
            if not multi_label:
                bb_outcomec = class_values[int(bb_outcomec)]
            else:
                bb_outcomec = multilabel2str(bb_outcomec, class_values)
            if bb_outcomec != crule.cons:
                continue

        if qlen < clen:
            clen = qlen
            crule_list = [crule]
            delta_list = [delta]
        elif delta not in delta_list:  # qlen == clen here
            crule_list.append(crule)
            delta_list.append(delta)
    return crule_list, delta_list
def get_falsified_conditions(xd, crule):
    """Collect the premises of ``crule`` that the instance ``xd`` violates.

    ``xd`` maps attribute names to values.  A ``'<='`` premise is violated
    when the value exceeds its threshold; a ``'>'`` premise when the value
    does not.

    Returns
    -------
    (delta, count)
        The violated premises in rule order, and how many there are.
    """
    violated = [
        cond for cond in crule.premises
        if (cond.op == '<=' and xd[cond.att] > cond.thr)
        or (cond.op == '>' and xd[cond.att] <= cond.thr)
    ]
    return violated, len(violated)
def _threshold_gap(thr):
    """Return the step added to ``thr`` to land strictly above it.

    1.0 for integral thresholds; otherwise ``10**-k``, where ``k`` is the
    position of the first non-zero digit after the decimal point.
    """
    if thr == int(thr):
        return 1.0
    # format_float_positional never switches to scientific notation, unlike
    # str() (str(1e-05) == '1e-05' has no '.' to split on).
    decimals = np.format_float_positional(thr).split('.')[1]
    for idx, digit in enumerate(decimals):
        if digit != '0':
            break
    return 1 / (10 ** (idx + 1))


def apply_counterfactual(x, delta, feature_names, features_map=None, features_map_inv=None, numeric_columns=None):
    """Return a copy of ``x`` changed so that every condition in ``delta`` holds.

    Numeric attributes are set to the threshold (``'<='``) or just above it
    (``'>'``, see _threshold_gap).  Other attributes are handled as one-hot
    columns named ``'name=value'``: ``'>'`` sets the column to 1.0 and
    ``'<='`` sets it to 0.0 (assumes thresholds lie between 0 and 1 -- TODO
    confirm).  When ``features_map`` is given, the sibling columns of the
    same original feature are set to the opposite value.

    ``numeric_columns`` lists the numeric attribute names.  When it is None,
    each condition's own ``is_continuous`` flag is used instead.

    Returns a numpy array ordered like ``feature_names``.
    """
    xd = vector2dict(x, feature_names)
    xcd = copy.deepcopy(xd)
    for p in delta:
        if numeric_columns is not None:
            is_numeric = p.att in numeric_columns
        else:
            is_numeric = p.is_continuous

        if is_numeric:
            if p.op == '>':
                xcd[p.att] = p.thr + _threshold_gap(p.thr)
            else:
                xcd[p.att] = p.thr
            continue

        fn = p.att.split('=')[0]
        on = 1.0 if p.op == '>' else 0.0
        if features_map is not None:
            fi = features_map_inv[list(feature_names).index(p.att)]
            for fv in features_map[fi]:
                # NOTE(review): for '<=' this sets *every* sibling column to
                # 1.0, which is not a valid one-hot vector -- confirm intent.
                xcd['%s=%s' % (fn, fv)] = 1.0 - on
        xcd[p.att] = on

    xc = np.zeros(len(xd))
    for i, name in enumerate(feature_names):
        xc[i] = xcd[name]
    return xc