import os
import pickle
from math import ceil
from glob import glob
from itertools import product
from collections import defaultdict
import numpy as np
from flopt import env as flopt_env
from flopt.solvers.solver_utils.common import value2str
from flopt.env import setup_logger
# Root directory used as the default ``load_prefix`` when LogVisualizer
# loads pickled logs (see LogVisualizer.load / LogVisualizer.load_log).
PERFORMANCE_DIR = flopt_env.PERFORMANCE_DIR
# Module-level logger; LogVisualizer.stat emits a debug message through it.
logger = setup_logger(__name__)
class LogVisualizer:
    """
    Log visualizer from logs.

    Logs are either passed to the constructor or loaded from the
    performance directory with ``load`` / ``load_log``.

    Parameters
    ----------
    logs : dict
        logs[dataset, instance, solver_name] = log

    Examples
    --------

    .. code-block:: python

        log_visualizer = LogVisualizer()
        log_visualizer.load(
            solver_names=['Random', '2-Opt'],
            datasets=['tsp']
        )
        log_visualizer.plot()
    """

    def __init__(self, logs=None):
        # A new dict is created per instance; a dict given by the caller is
        # stored as-is (not copied), so later loads are visible to the caller.
        self.logs = {} if logs is None else logs

    def load(self, solver_names, datasets, load_prefix=PERFORMANCE_DIR):
        """
        load the logs of every (solver_name, dataset) combination

        Parameters
        ----------
        solver_names : str or list of str
            solver name(s); a single string is treated as a one-element list
        datasets : str or list of str
            dataset name(s); a single string is treated as a one-element list
        load_prefix : str
            log saved path
        """
        if isinstance(solver_names, str):
            solver_names = [solver_names]
        if isinstance(datasets, str):
            datasets = [datasets]
        for solver_name, dataset in product(solver_names, datasets):
            self.load_log(solver_name, dataset, load_prefix)

    def load_log(self, solver_name, dataset, load_prefix=PERFORMANCE_DIR):
        """
        load log pickle file from load_prefix/solver_name/dataset/instance/log.pickle

        Parameters
        ----------
        solver_name : str
            solver name
        dataset : str
            dataset name
        load_prefix : str
            log saved path
        """
        for picklefile in glob(f"{load_prefix}/{solver_name}/{dataset}/*/log.pickle"):
            # The instance name is the directory holding log.pickle. Use
            # os.path instead of splitting on "/" so the right component is
            # picked even when glob returns another path separator.
            instance_name = os.path.basename(os.path.dirname(picklefile))
            # NOTE: pickle.load can execute arbitrary code; only point
            # load_prefix at directories whose content is trusted.
            with open(picklefile, "rb") as pf:
                self.logs[dataset, instance_name, solver_name] = pickle.load(pf)

    def plot(
        self, xitem="time", yscale="linear", plot_type="all", save_prefix=None, col=2
    ):
        """
        plot all logs

        Parameters
        ----------
        xitem : str
            x-label name. 'time' or 'iteration'
        yscale : str
            linear or log
        plot_type : str
            all: create figures for each dataset.
            each: create figures for each instance.
            noshow: do not create figures.
        save_prefix : str or None
            if None, figures are displayed with ``plt.show()``; otherwise
            they are saved to ``{save_prefix}{dataset}.pdf`` (all) or
            ``{save_prefix}{dataset}/{instance}.pdf`` (each)
        col : int
            #columns of figure

        Raises
        ------
        ValueError
            if plot_type is not one of 'all', 'each' or 'noshow'
        """
        if plot_type == "noshow":
            # Documented as "do not create figures": nothing to draw.
            return
        if plot_type not in ("all", "each"):
            raise ValueError(
                f"plot_type must be 'all', 'each' or 'noshow', got {plot_type!r}"
            )

        # Imported lazily so the module can be imported without matplotlib.
        import matplotlib.pyplot as plt

        # Sorted so that figure / subplot / legend order is reproducible.
        for dataset in sorted({d for d, _, _ in self.logs}):
            instances = sorted({i for d, i, _ in self.logs if d == dataset})
            n_instance = len(instances)
            # Use a local so that a single-instance dataset does not force
            # one column on every dataset processed after it.
            n_col = 1 if n_instance < 2 else col

            if plot_type == "all":
                n_row = ceil(n_instance / n_col)
                # squeeze=False always returns a 2-D array of Axes, so the
                # 1x1, 1xN and Nx1 layouts are all indexed the same way.
                fig, grid = plt.subplots(n_row, n_col, squeeze=False)
                fig.suptitle(dataset)

            for idx, instance in enumerate(instances):
                if plot_type == "each":
                    fig, ax = plt.subplots()
                else:
                    # Row-major position of this instance in the grid.
                    ax = grid.flat[idx]

                solver_names = sorted(
                    s for d, i, s in self.logs if (d, i) == (dataset, instance)
                )
                for solver_name in solver_names:
                    log = self.logs[dataset, instance, solver_name]
                    log.plot(
                        show=False,
                        xitem=xitem,
                        linestyle="--",
                        marker=".",
                        label=solver_name,
                        fig=fig,
                        ax=ax,
                    )
                setax(ax, instance, yscale)

                if plot_type == "each":
                    if save_prefix is None:
                        plt.show()
                    else:
                        save_fig(fig, f"{save_prefix}{dataset}/{instance}.pdf")

            if plot_type == "all":
                if save_prefix is None:
                    plt.show()
                else:
                    save_fig(fig, f"{save_prefix}{dataset}.pdf")

    def stat(self, time=None, iteration=None):
        """display static information

        For every dataset a table (one row per instance, one column per
        solver) of objective values is printed, followed by the number of
        wins, a score and a ranking per solver. A lower objective value is
        treated as better; a missing log counts as +inf.

        Parameters
        ----------
        time : int or float
            summary logs whose time less than time
        iteration : int
            summary logs whose iteration less than iteration
        """
        logger.debug(f"summary logs time={time}, iteration={iteration}")
        for dataset in sorted({d for d, _, _ in self.logs}):
            solvers = sorted({s for d, _, s in self.logs if d == dataset})
            instances = sorted({i for d, i, _ in self.logs if d == dataset})
            num_wins = [0] * len(solvers)
            scores = [0] * len(solvers)

            rows = []
            for instance in instances:
                row = [instance]
                obj_values = []
                for solver in solvers:
                    key = (dataset, instance, solver)
                    if key in self.logs:
                        log = self.logs[key].getLog(time=time, iteration=iteration)
                        obj_value = log["obj_value"]
                        row.append(value2str(obj_value))
                        obj_values.append(obj_value)
                    else:
                        # No log for this solver on this instance.
                        row.append("")
                        obj_values.append(float("inf"))
                rows.append(row)

                # Stable sort: solvers with equal values keep name order.
                order = np.argsort(obj_values, kind="stable")
                best_obj = obj_values[order[0]]
                # Every solver within 1e-10 of the best value is a winner.
                n_win = 1
                while (
                    n_win < len(order)
                    and abs(best_obj - obj_values[order[n_win]]) < 1e-10
                ):
                    n_win += 1
                # Winners score 0; the others score their sorted position.
                for place, solver_ix in enumerate(order):
                    if place < n_win:
                        num_wins[solver_ix] += 1
                    else:
                        scores[solver_ix] += place

            print("\n".join(["", "", f"{dataset}", "=" * len(dataset), ""]))

            messages = [
                ["Instance"] + solvers,
                ["-" * 8] + ["-" * len(s) for s in solvers],
            ]
            messages += rows
            messages.append([""])
            messages.append(["#Win"] + [str(n) for n in num_wins])
            messages.append(["Score"] + [str(s) for s in scores])

            # simple ranking: rank 1 is the solver with the lowest score
            ranks = [0] * len(solvers)
            for rank, solver_ix in enumerate(np.argsort(scores, kind="stable"), 1):
                ranks[solver_ix] = rank
            messages.append(["Ranking"] + [str(r) for r in ranks])

            # Right-align every cell to the widest entry of its column.
            widths = defaultdict(int)
            for message in messages:
                for column, cell in enumerate(message):
                    widths[column] = max(widths[column], len(cell))
            for message in messages:
                print(
                    " ".join(
                        cell.rjust(widths[column])
                        for column, cell in enumerate(message)
                    )
                )
            print()

            print("stas")
            print(f"s# time,{','.join(solvers)}")
            print(f"s## {time},{','.join(map(str, ranks))}")

    def __len__(self):
        """Return the number of stored (dataset, instance, solver_name) logs."""
        return len(self.logs)
def save_fig(fig, path):
    """
    Save a figure to ``path`` and close it.

    Parameters
    ----------
    fig : matplotlib figure
        figure to save; only ``fig.savefig`` is called on it here
    path : str
        destination file; missing parent directories are created
    """
    # pyplot is not imported at module level in this file, so import it here.
    import matplotlib.pyplot as plt

    dirname = os.path.dirname(path)
    # dirname is "" for a bare file name; os.makedirs("") would raise.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    fig.savefig(path, bbox_inches="tight")
    # Close this figure explicitly rather than whichever one is current.
    plt.close(fig)
def iter_axs(axs, col):
    """
    Iterate over the axes of a subplot grid in row-major order.

    Parameters
    ----------
    axs : object, sequence or numpy.ndarray
        a single axis, a 1-D collection of axes, or a 2-D (row, col) grid
    col : int
        #columns of the grid; kept for backward compatibility, the layout
        is taken from ``axs`` itself

    Yields
    ------
    object
        each element of ``axs``, left to right then top to bottom
    """
    # dtype=object keeps the elements untouched; ravel() flattens 0-d, 1-D
    # and 2-D inputs alike, so squeezed grids (1xN, Nx1, 1x1) are walked
    # element by element instead of being mis-indexed.
    cells = np.asarray(axs, dtype=object)
    yield from cells.ravel()
def setax(ax, title, yscale):
    """
    Apply the common decoration to one axis.

    Parameters
    ----------
    ax : matplotlib axis
        axis to decorate
    title : str
        title of the axis
    yscale : str
        y-axis scale passed to ``ax.set_yscale`` (e.g. linear or log)
    """
    # "--" is a line style, not the on/off flag: pass it as linestyle so the
    # grid is actually drawn dashed.
    ax.grid(True, linestyle="--")
    ax.legend()
    ax.set_title(title)
    ax.set_yscale(yscale)