import matplotlib.pyplot as plt

import tree_utils as tu
def plot(D, attr=None, ofn=None):
    """Draw the tree described by ``D`` with matplotlib and save it to disk.

    ``D`` maps node names to per-node dicts and also holds a ``'meta'``
    entry providing ``'max_xy'`` (an ``(x, y)`` pair), ``'e_node_names'``
    (external nodes) and ``'i_node_names'`` (internal nodes; the first one
    is drawn as the root).  Every drawn node dict provides ``'x'``, ``'y'``
    and, except the root, ``'dist_to_parent'``; internal node dicts also
    provide ``'y_bott'`` and ``'y_top'`` for their vertical connector.
    (Presumably ``D`` is built by ``tree_utils.make_tree_dict`` - confirm.)

    Parameters
    ----------
    D : dict
        Tree dictionary as described above.
    attr : dict, optional
        Plotting attributes.  Keys missing here fall back to the values from
        :func:`get_default_attr`; ``None`` or ``{}`` means all defaults.
    ofn : str, optional
        Output file name *without* extension; ``attr['figure_type']`` is
        appended.  A falsy value saves to ``'example.<figure_type>'``.

    Drawing happens on the current pyplot figure/axes; the figure is left
    open after saving.
    """
    attr = _resolve_attr(attr)
    meta = D['meta']
    maxx, maxy = meta['max_xy']
    e_node_names = meta['e_node_names']
    i_node_names = meta['i_node_names']

    e_node_dicts = [D[n] for n in e_node_names]
    i_node_dicts = [D[n] for n in i_node_names]

    _draw_external_nodes(e_node_names, e_node_dicts, attr)
    _draw_internal_nodes(i_node_dicts, attr)
    if attr['e_node_labels_visible']:
        _draw_external_labels(e_node_names, e_node_dicts, maxx, attr)

    # Extra room to the right of the tree for labels, proportional to the
    # longest external node name (0 when there are no external nodes).
    longest = max([len(name) for name in e_node_names] or [0])
    max_label_width = maxx / 100.0 * longest
    max_label_width *= attr['label_width_factor']

    # plt.gca() returns the axes everything above was drawn on.  The former
    # plt.axes() call creates a brand-new empty Axes in current matplotlib,
    # so the settings below would not have applied to the tree.
    ax = plt.gca()
    if not attr['vertical_axis_visible']:
        ax.yaxis.set_visible(False)
    if not attr['horizontal_axis_visible']:
        ax.xaxis.set_visible(False)
    ax.set_xlim(-maxx / 10.0, maxx * 1.1 + max_label_width)
    ax.set_ylim(-maxy / 10.0, maxy * 1.1)

    base = ofn if ofn else 'example'
    plt.savefig(base + '.' + attr['figure_type'])


def _resolve_attr(attr):
    """Return a full attribute dict: the defaults overlaid with ``attr``."""
    merged = get_default_attr()
    if attr:
        merged.update(attr)
    return merged


def _draw_external_nodes(names, node_dicts, attr):
    """Draw each external node's horizontal bar and, optionally, its dot."""
    size = attr['e_node_dot_size']
    lw = attr['line_width']
    bar_colors = attr['e_node_bar_color_list']
    for i, (name, nD) in enumerate(zip(names, node_dicts)):
        x, y = nD['x'], nD['y']
        # The bar extends dist_to_parent to the left of the node.
        xleft = x - nD['dist_to_parent']
        # A per-node colour list, when given, is indexed in e_node_names order.
        c = bar_colors[i] if bar_colors else attr['e_node_bar_color']
        plt.plot([xleft, x], [y, y], color=c, lw=lw, zorder=1)
        if attr['e_node_dots_visible']:
            if attr['using_e_node_specific_colors']:
                c = attr['dot_color_dict'][name]
            else:
                c = attr['default_e_node_dot_color']
            plt.scatter(x, y, color=c, s=size, zorder=2)


def _draw_internal_nodes(node_dicts, attr):
    """Draw internal nodes; ``node_dicts[0]`` is treated as the root."""
    if not node_dicts:
        return
    size = attr['e_node_dot_size']  # same dot size as external nodes
    lw = attr['line_width']
    vcolor = attr['i_node_vertical_bar_color']

    for nD in node_dicts[1:]:
        x, y = nD['x'], nD['y']
        x0 = x - nD['dist_to_parent']
        plt.plot([x0, x], [y, y], color=attr['i_node_bar_color'],
                 lw=lw, zorder=1)
        plt.plot([x, x], [nD['y_bott'], nD['y_top']], color=vcolor,
                 lw=lw, zorder=1)
        if attr['i_node_dots_visible']:
            plt.scatter(x, y, color=attr['i_node_dot_color'], s=size,
                        zorder=2)

    # The root gets only a vertical connector (no horizontal bar) and its
    # own dot colour.
    root = node_dicts[0]
    x, y = root['x'], root['y']
    plt.plot([x, x], [root['y_bott'], root['y_top']], color=vcolor,
             lw=lw, zorder=1)
    if attr['i_node_dots_visible']:
        plt.scatter(x, y, color=attr['r_node_dot_color'], s=size, zorder=2)


def _draw_external_labels(names, node_dicts, maxx, attr):
    """Write a text label just to the right of every external node."""
    dx = maxx / 100.0 * 4  # horizontal offset: 4% of maxx
    for name, nD in zip(names, node_dicts):
        if attr['using_alternate_names']:
            s = attr['alternate_name_dict'][name]
        else:
            s = name
        if attr['using_e_node_specific_colors']:
            c = attr['node_label_color_dict'][name]
        else:
            c = attr['e_node_label_default_color']
        plt.text(nD['x'] + dx, nD['y'], s, fontname='Helvetica',
                 fontsize=attr['e_node_label_size'], color=c,
                 ha='left', va='center')


def get_default_attr():
    """Return a fresh dict of the default plotting attributes used by plot()."""
    attr = dict()
    attr['e_node_dot_size'] = 75          # scatter size for all node dots
    attr['line_width'] = 2
    attr['figure_type'] = 'png'           # output file extension
    attr['horizontal_axis_visible'] = True
    attr['vertical_axis_visible'] = False
    attr['e_node_bar_color'] = 'r'
    attr['e_node_bar_color_list'] = None  # optional per-external-node colours
    attr['e_node_dots_visible'] = True
    attr['e_node_labels_visible'] = True
    attr['default_e_node_dot_color'] = 'k'
    attr['e_node_label_size'] = 14
    attr['e_node_label_default_color'] = 'k'
    attr['label_width_factor'] = 1.1      # scales the room reserved for labels
    attr['i_node_bar_color'] = 'k'
    attr['i_node_vertical_bar_color'] = 'k'
    attr['i_node_dots_visible'] = False
    attr['i_node_dot_color'] = 'magenta'
    attr['r_node_dot_color'] = 'orange'   # root dot colour
    attr['using_alternate_names'] = False
    attr['alternate_name_dict'] = None    # node name -> displayed label
    attr['using_e_node_specific_colors'] = False
    attr['node_label_color_dict'] = None  # node name -> label colour
    attr['dot_color_dict'] = None         # node name -> dot colour
    return attr


def print_defaults():
    """Print every default attribute, one per line, sorted by key.

    Underscores in keys are shown as spaces, and keys are right-justified to
    the longest key so the values line up.
    """
    attr = get_default_attr()
    width = max([len(k) for k in attr])
    for k in sorted(attr):
        # Single-argument print(...) produces identical output on
        # Python 2 and 3 (the old print statement is a Py3 SyntaxError).
        print(k.replace('_', ' ').rjust(width) + '   ' + str(attr[k]))
if __name__ == '__main__':
    # Demo run: build the tree from 'tree.txt', render it with the default
    # attributes, then list those defaults on stdout.
    tree = tu.make_tree_dict(ts=tu.load_data('tree.txt'))
    plot(tree)
    print_defaults()