Source code for garage.experiment.nb_utils

import csv
from glob import glob
import json
import os
import os.path as osp

import joblib
import matplotlib.pyplot as plt
import numpy as np


def plot_experiments(name_or_patterns,
                     legend=False,
                     post_processing=None,
                     key='AverageReturn',
                     data_folder=None):
    """Plot one logged column for every experiment folder matching a pattern.

    Every matched path is treated as a folder containing ``progress.csv``.
    The column named ``key`` is read from that file and drawn as one line
    with ``plt.plot``.

    Args:
        name_or_patterns (str or list[str] or tuple[str]): Folder name(s) or
            glob pattern(s), joined onto ``data_folder`` before globbing.
        legend (bool): If truthy, add a legend that labels each line with the
            base name of its experiment folder.
        post_processing (callable or None): Optional function applied to the
            array of values before it is plotted.
        key (str): Name of the ``progress.csv`` column to plot. Rows whose
            cell in this column is empty are skipped.
        data_folder (str or None): Directory the patterns are resolved
            against. When ``None`` the historical default is used:
            ``../../data`` relative to this module's file.

    Raises:
        KeyError: If a row of a matched ``progress.csv`` has no ``key``
            column.
        OSError: If ``progress.csv`` cannot be opened under a matched path.

    """
    if not isinstance(name_or_patterns, (list, tuple)):
        name_or_patterns = [name_or_patterns]
    if data_folder is None:
        data_folder = osp.abspath(
            osp.join(osp.dirname(__file__), '../../data'))

    files = []
    for name_or_pattern in name_or_patterns:
        files += glob(osp.join(data_folder, name_or_pattern))
    files = sorted(files)

    print('plotting the following experiments:')
    for f in files:
        print(f)

    plots = []
    legends = []
    for f in files:
        exp_name = osp.basename(f)
        # The csv module needs a text-mode file on Python 3 (binary mode
        # makes DictReader fail on the first row); newline='' lets the
        # reader handle line endings itself.
        with open(osp.join(f, 'progress.csv'), 'r', newline='') as csvfile:
            returns = [
                float(row[key]) for row in csv.DictReader(csvfile)
                if row[key]
            ]
        returns = np.array(returns)
        if post_processing:
            returns = post_processing(returns)
        plots.append(plt.plot(returns)[0])
        legends.append(exp_name)

    if legend:
        plt.legend(plots, legends)
class Experiment:
    """Container for the data of a single experiment run.

    Args:
        progress: Logged progress data, stored unchanged as ``progress``.
        params (dict): Experiment parameters. Must contain ``'exp_name'``,
            which becomes ``name``.
        pkl_data: Optional extra data, stored unchanged as ``pkl_data``.

    Attributes set by the constructor: ``progress``, ``params``,
    ``pkl_data``, ``flat_params`` (flattened copy of ``params``) and
    ``name``.

    """

    def __init__(self, progress, params, pkl_data=None):
        self.progress = progress
        self.params = params
        self.pkl_data = pkl_data
        self.flat_params = self._flatten_params(params)
        self.name = params['exp_name']

    def _flatten_params(self, params, depth=2):
        """Flatten nested dictionaries into a single-level dictionary.

        A nested entry ``params[k][subk]`` is stored under ``k + '_' + subk``,
        except when ``subk`` is ``'_name'``, in which case it is stored
        directly under ``k``.

        Args:
            params (dict): Possibly nested parameter dictionary.
            depth (int): Number of nesting levels to expand. When it reaches
                zero, remaining dictionaries are kept as values.

        Returns:
            dict: The flattened parameters.

        """
        flattened = {}
        for name, value in params.items():
            # Leaves, and dictionaries beyond the depth budget, are copied
            # through untouched.
            if depth == 0 or not isinstance(value, dict):
                flattened[name] = value
                continue
            nested = self._flatten_params(value, depth=depth - 1)
            for inner_name, inner_value in nested.items():
                target = (name if inner_name == '_name' else
                          name + '_' + inner_name)
                flattened[target] = inner_value
        return flattened
def uniq(seq):
    """Return the distinct items of ``seq`` in first-occurrence order.

    Hashable items are tracked in a set. Unhashable items (for example lists
    or dicts) are tracked in a list and compared by equality, so they no
    longer raise ``TypeError``.

    Args:
        seq (iterable): Items to de-duplicate.

    Returns:
        list: Items of ``seq`` with later duplicates removed.

    """
    seen_hashable = set()
    seen_unhashable = []
    distinct = []
    for item in seq:
        try:
            if item in seen_hashable:
                continue
            seen_hashable.add(item)
        except TypeError:
            # Unhashable item: fall back to a linear equality scan.
            if item in seen_unhashable:
                continue
            seen_unhashable.append(item)
        distinct.append(item)
    return distinct
class ExperimentDatabase:
    """Collection of experiments loaded from a data folder.

    Args:
        data_folder (str): Directory the patterns are resolved against.
        names_or_patterns (str or list[str] or tuple[str]): Name(s) or glob
            pattern(s) selecting experiment folders inside ``data_folder``.
            The patterns may instead match ``progress.csv`` / ``params.json``
            / ``params.pkl`` files directly, in which case those files are
            combined into a single experiment.

    """

    def __init__(self, data_folder, names_or_patterns='*'):
        self._load_experiments(data_folder, names_or_patterns)

    def _read_data(self, progress_file):
        """Read a progress CSV into one float array per column.

        Args:
            progress_file (str): Path of the CSV file.

        Returns:
            dict: Maps each column name to a ``numpy`` array of floats.
                Empty or missing cells become ``nan`` so that one blank
                cell does not abort loading of the whole file.

        """
        columns = dict()
        # The csv module needs a text-mode file on Python 3; newline='' lets
        # the reader handle line endings itself.
        with open(progress_file, 'r', newline='') as csvfile:
            for row in csv.DictReader(csvfile):
                for k, v in row.items():
                    value = np.nan if v is None or v == '' else float(v)
                    columns.setdefault(k, []).append(value)
        return {k: np.array(v) for k, v in columns.items()}

    def _read_params(self, params_file):
        """Parse a JSON parameter file.

        Args:
            params_file (str): Path of the JSON file.

        Returns:
            The decoded JSON content.

        """
        with open(params_file, 'r') as f:
            return json.load(f)

    def _load_experiments(self, data_folder, name_or_patterns):
        """Populate ``self._experiments`` from the matched paths.

        Args:
            data_folder (str): Directory the patterns are resolved against.
            name_or_patterns (str or list[str] or tuple[str]): Name(s) or
                glob pattern(s) to match inside ``data_folder``.

        """
        if not isinstance(name_or_patterns, (list, tuple)):
            name_or_patterns = [name_or_patterns]

        # glob gives a list of all paths satisfying a pattern. A path that
        # satisfies two patterns would appear twice, so keep only its first
        # occurrence.
        files = []
        already_matched = set()
        for name_or_pattern in name_or_patterns:
            for path in glob(osp.join(data_folder, name_or_pattern)):
                if path not in already_matched:
                    already_matched.add(path)
                    files.append(path)

        experiments = []
        # State for the "flat" layout, where the patterns match the data
        # files themselves instead of experiment folders. Kept separate from
        # the per-folder variables so one experiment's pickle can never be
        # attached to another experiment.
        progress_f = None
        params_f = None
        pkl_f = None
        flat_dir = None
        for f in files:
            if os.path.isdir(f):
                try:
                    progress = self._read_data(osp.join(f, 'progress.csv'))
                    params = self._read_params(osp.join(f, 'params.json'))
                    # normpath drops a trailing separator, which would
                    # otherwise make basename return ''.
                    params['exp_name'] = osp.basename(osp.normpath(f))
                    pkl_data = None
                    pkl_path = osp.join(f, 'params.pkl')
                    if os.path.isfile(pkl_path):
                        # NOTE: joblib.load unpickles the file; only load
                        # snapshots that come from a trusted source.
                        pkl_data = joblib.load(pkl_path)
                    experiments.append(Experiment(progress, params, pkl_data))
                except Exception as e:
                    # Best effort: report the folder's problem and move on.
                    print(e)
            elif 'progress.csv' in f:
                # in case you're giving as datafolder the dir that contains
                # the files!
                progress_f = self._read_data(f)
                flat_dir = osp.dirname(osp.abspath(f))
            elif 'params.json' in f:
                params_f = self._read_params(f)
            elif 'params.pkl' in f:
                print('about to load', f)
                # NOTE: joblib.load unpickles the file; only load snapshots
                # that come from a trusted source.
                pkl_f = joblib.load(f)

        if params_f and progress_f:
            # Folder experiments are named after their folder. Do the same
            # here when the params file carries no name, since Experiment
            # requires 'exp_name'.
            params_f.setdefault('exp_name', osp.basename(flat_dir))
            experiments.append(Experiment(progress_f, params_f, pkl_f))
        self._experiments = experiments

    def plot_experiments(self,
                         key=None,
                         legend=None,
                         color_key=None,
                         filter_exp=None,
                         **kwargs):
        """Plot one progress column for every selected experiment.

        Args:
            key (str): Name of the progress column to plot. Experiments
                without that column are drawn as the single point ``[0]``.
            legend (str or None): Flattened parameter name whose value labels
                each line in the legend.
            color_key (str or None): Flattened parameter name used to colour
                the lines; experiments sharing a value share a colour. When
                ``legend`` is ``None`` the legend lists one entry per value
                of this parameter that is not ``None``.
            filter_exp (callable or None): Optional predicate; only
                experiments for which it returns a truthy value are plotted.
            **kwargs: Exact-match parameter filters, forwarded to
                ``filter_experiments``.

        Raises:
            NotImplementedError: If ``color_key`` takes more distinct values
                than there are colours available.

        """
        experiments = list(self.filter_experiments(**kwargs))
        if filter_exp:
            experiments = list(filter(filter_exp, experiments))

        plots = []
        legends = []
        color_pool = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w']
        color_map = dict()
        if color_key is not None:
            exp_color_keys = uniq(
                [exp.flat_params.get(color_key, None) for exp in experiments])
            if len(exp_color_keys) > len(color_pool):
                raise NotImplementedError(
                    'more distinct values of the color key than available '
                    'colors')
            for exp_color_key, color in zip(exp_color_keys, color_pool):
                print('%s: %s' % (str(exp_color_key), color))
            color_map = dict(zip(exp_color_keys, color_pool))

        used_legends = []
        legend_list = []
        for exp in experiments:
            exp_color_key = None
            exp_color = None
            if color_key is not None:
                exp_color_key = exp.flat_params.get(color_key, None)
                exp_color = color_map.get(exp_color_key, None)
            plots.append(
                plt.plot(exp.progress.get(key, [0]), color=exp_color)[0])
            if legend is not None:
                legends.append(exp.flat_params[legend])
            elif (exp_color_key is not None
                  and exp_color_key not in used_legends):
                # One legend entry per colour: remember the first line drawn
                # for each colour-key value.
                used_legends.append(exp_color_key)
                legend_list.append(plots[-1])

        if legends:
            plt.legend(plots, legends)
        elif legend_list:
            plt.legend(legend_list, used_legends)

    def filter_experiments(self, **kwargs):
        """Yield the experiments whose flattened parameters match exactly.

        Args:
            **kwargs: Required parameter values. A parameter that an
                experiment does not have is treated as ``None``.

        Yields:
            Experiment: Each loaded experiment satisfying every filter.

        """
        for exp in self._experiments:
            exp_params = exp.flat_params
            if any(exp_params.get(name, None) != wanted
                   for name, wanted in kwargs.items()):
                continue
            yield exp

    def unique(self, param_key):
        """List the distinct values of one flattened parameter.

        Args:
            param_key (str): Flattened parameter name.

        Returns:
            list: Distinct values in first-seen order, taken from the
                experiments that define ``param_key``.

        """
        return uniq([
            exp.flat_params[param_key] for exp in self._experiments
            if param_key in exp.flat_params
        ])