import dot2tex
import graphviz
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from .stats import ResultsPartiallyConfounded, ResultsFullyConfounded
def _pval_to_str(pval, alpha=0.05, floor=0.0001):
if pval is None:
return ''
elif pval > alpha:
return 'p=' + str(np.round(pval, 2))
elif pval > floor:
return 'p=' + str(np.round(pval, np.ceil(-np.log10(pval)).astype(int)))
else:
return 'p<' + str(floor)
def plot_null_dist(confound_test_results, **kwargs):
    """
    Plot the histogram of the permutation-based null distribution of the confounder test.

    The value of ``confound_test_results.r2_yhat_c`` is marked with a red
    vertical line on the same axes as the histogram.

    Parameters
    ----------
    confound_test_results : namedtuple
        The object returned by `test_partially_confounded` or `test_fully_confounded`.
    kwargs : dict
        Additional named arguments, passed to seaborn.histplot.

    Returns
    -------
    matplotlib.axes.Axes
        The matplotlib axes containing the plot.

    Raises
    ------
    RuntimeError
        If `confound_test_results` has no ``null_distribution`` attribute,
        or that attribute is None.

    See Also
    --------
    plot_graph
    """
    # A result object may carry the field but leave it unset (None); treat
    # that the same as a missing attribute instead of plotting nothing.
    null_distribution = getattr(confound_test_results, 'null_distribution', None)
    if null_distribution is None:
        raise RuntimeError("No null distribution data is available. "
                           "Re-run 'confound_test' with 'return_null_dist=True'!")

    ax = sns.histplot(null_distribution, **kwargs)
    ax.set(xlabel='R2(y^,c*)', ylabel='count')
    ax.set_title('null distribution')
    # Draw on the axes returned by histplot rather than pyplot's "current"
    # axes, so the marker ends up on the right plot when `ax=` is passed.
    ax.axvline(confound_test_results.r2_yhat_c, color='red')
    return ax
def plot_graph(confound_test_results, y_name='y', yhat_name='<ŷ>', c_name='c',
               outfile_base=None, precision=3):
    """
    Draw the confounder test results as a graph of the involved variables.

    Parameters
    ----------
    confound_test_results : namedtuple
        The object returned by `test_partially_confounded` or `test_fully_confounded`.
    y_name : str
        Name of the target variable.
    yhat_name : str
        Name for the model predictions.
    c_name : str
        Name of the confounder variable.
    outfile_base : str
        Path for output files, without extension (None: figure not saved).
    precision : int
        Precision for r squared values.

    Returns
    -------
    dot
        The graphviz object to plot.

    See Also
    --------
    plot_null_dist
    """
    # Anything that is not a partial-confounder result is drawn in 'full' mode.
    is_partial = isinstance(confound_test_results, ResultsPartiallyConfounded)
    mode = 'partial' if is_partial else 'full'

    return plot_r2_graph(
        r2_y_c=confound_test_results.r2_y_c,
        r2_yhat_c=confound_test_results.r2_yhat_c,
        r2_y_yhat=confound_test_results.r2_y_yhat,
        p_yhat_c=confound_test_results.p,
        y_name=y_name,
        yhat_name=yhat_name,
        c_name=c_name,
        mode=mode,
        outfile_base=outfile_base,
        precision=precision,
    )
def plot_r2_graph(r2_y_c, r2_yhat_c, r2_y_yhat, p_yhat_c=None,
                  y_name='y', yhat_name='yhat', c_name='c',
                  mode='partial',
                  outfile_base=None,
                  precision=3,
                  alpha=0.05,
                  minp=0.0001):
    """
    Build a three-node graph (c, y, yhat) with R squared values as edge labels.

    Parameters
    ----------
    r2_y_c : float
        Value shown on the edge between c and y.
    r2_yhat_c : float
        Value shown on the edge between c and yhat.
    r2_y_yhat : float
        Value shown on the edge between y and yhat.
    p_yhat_c : float or None
        P-value appended to the label of the dashed edge (c--yhat in
        'partial' mode, y--yhat in 'full' mode). None: the edge is still
        dashed but no p-value is appended.
    y_name : str
        Label of the y node.
    yhat_name : str
        Label of the yhat node.
    c_name : str
        Label of the c node.
    mode : str
        Either 'partial' or 'full'; selects which edge is dashed and annotated.
    outfile_base : str
        Path for output files, without extension (None: nothing is saved).
        Otherwise the graph is rendered from ``outfile_base + '.dot'`` and
        tikz code is written to ``outfile_base + '.tex'``.
    precision : int
        Number of decimals the R squared values are rounded to.
    alpha : float
        P-values below this threshold are marked with a star.
    minp : float
        Floor passed to the p-value formatter.

    Returns
    -------
    graphviz.Graph
        The graph object.

    Raises
    ------
    AttributeError
        If `mode` is neither 'partial' nor 'full'.
    """
    # Validate before building anything, so bad input fails fast.
    if mode not in ('partial', 'full'):
        raise AttributeError("Mode must be either 'partial' or 'full'.")

    # p_yhat_c defaults to None; comparing None with alpha would raise a
    # TypeError, so skip the annotation altogether in that case.
    if p_yhat_c is None:
        pvalstr = ''
    else:
        star = '*' if p_yhat_c < alpha else ''
        pvalstr = ' (' + _pval_to_str(p_yhat_c, alpha, minp) + star + ')'

    def _r2_label(value):
        return str(np.round(value, precision))

    dot = graphviz.Graph()
    dot.attr(rankdir='LR')

    dot.node('c', label=c_name)
    dot.node('y', label=y_name)
    dot.node('yhat', label=yhat_name)

    if mode == 'partial':
        dot.edge('c', 'yhat', label=_r2_label(r2_yhat_c) + pvalstr, style="dashed")
    else:
        dot.edge('c', 'yhat', label=_r2_label(r2_yhat_c))

    dot.edge('c', 'y', label=_r2_label(r2_y_c))

    if mode == 'full':
        dot.edge('y', 'yhat', label=_r2_label(r2_y_yhat) + pvalstr, style="dashed")
    else:
        dot.edge('y', 'yhat', label=_r2_label(r2_y_yhat))

    if outfile_base is not None:
        dot.render(filename=outfile_base + '.dot')  # saves dot and pdf
        tex_code = dot2tex.dot2tex(dot.source, format='tikz', crop=True)
        # Explicit UTF-8: node labels may contain non-ASCII characters, which
        # a locale-dependent default encoding may be unable to write.
        with open(outfile_base + '.tex', "w", encoding="utf-8") as text_file:
            text_file.write(tex_code)

    return dot