# Source code for pystoned.plot

# import dependencies
import matplotlib.pyplot as plt
import numpy as np
from .utils import interpolation
from .constant import FUN_PROD, FUN_COST, RED_MOM, RTS_VRS, RTS_CRS


def plot2d(model, x_select=0, label_name="estimated function", fig_name=None, method=RED_MOM):
    """Plot 2d estimated function/frontier

    Draws the observed data as a scatter plot and the estimated
    function/frontier as a line, both against the selected input.

    Args:
        model: The input model for plotting. Must expose ``x``, ``y`` and
            ``get_frontier()``.
        x_select (Integer): Index of the input plotted on the horizontal axis.
        label_name (String): the estimator name, used as the line's legend label.
        fig_name (String, optional): The name of figure to save. Defaults to
            None, in which case the figure is shown instead of saved.
        method (optional): Passed to ``model.get_frontier`` only when the model's
            class is named ``StoNED``; ignored for every other model.
            Defaults to RED_MOM.

    Returns:
        False when ``model.y`` is not one-dimensional (multiple outputs cannot
        be plotted); otherwise None.
    """
    x = np.array(model.x).T[x_select]
    y = np.array(model.y).T
    if y.ndim != 1:
        print("Plot with mutiple y is unavailable now.")
        return False

    # Only StoNED's get_frontier takes the extra method argument.
    if model.__class__.__name__ != "StoNED":
        yhat = np.array(model.get_frontier()).T
    else:
        yhat = np.array(model.get_frontier(method)).T

    # Stack the three series (np.stack also checks that their lengths agree)
    # and reorder the columns by x so the fitted line is drawn left to right.
    data = np.stack([x, y, yhat], axis=0)
    data = data[:, np.argsort(data[0])]
    x, y, f = data

    _, ax = plt.subplots()
    dp = ax.scatter(x, y, color="k", marker='x')
    fl = ax.plot(x, f, color="r", label=label_name)

    ax.legend([dp, fl[0]], ['Data points', label_name],
              loc='upper left', ncol=1, fontsize=10, frameon=False)

    ax.set_xlabel("Input $x$")
    ax.set_ylabel("Output $y$")

    # Remove top and right axes
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)

    if fig_name is None:
        plt.show()
    else:
        plt.savefig(fig_name)
def plot3d(model, x_select_1=0, x_select_2=1, fig_name=None, line_transparent=False, pane_transparent=False):
    """Plot 3d estimated function/frontier

    Draws the observed data as a 3d scatter over two selected inputs, and the
    estimated function as a surface evaluated on a 30 x 30 grid spanning the
    observed range of those two inputs.

    Args:
        model: The input model for plotting. Must expose ``x``, ``y``, ``rts``,
            ``fun``, ``get_beta()`` and, when ``rts`` is RTS_VRS, ``get_alpha()``.
        x_select_1 (Integer): Index of the input plotted on the x-axis.
        x_select_2 (Integer): Index of the input plotted on the y-axis.
        fig_name (String, optional): The name of figure to save. Defaults to
            None, in which case the figure is shown instead of saved.
        line_transparent (bool, optional): Grid-line switch. NOTE: the grid lines
            are made transparent when this is False (the default) and left
            visible when it is True. Defaults to False.
        pane_transparent (bool, optional): When True, make the background panes
            transparent. Defaults to False.

    Returns:
        False when ``model.y`` is not one-dimensional (multiple outputs cannot
        be plotted); otherwise None.

    Raises:
        ValueError: If ``model.rts`` is neither RTS_VRS nor RTS_CRS.
    """
    x = np.array(model.x).T
    y = np.array(model.y).T
    # Bail out before touching the model's estimates.
    if y.ndim != 1:
        print("Plot with mutiple y is unavailable now.")
        return False

    if model.rts == RTS_VRS:
        alpha, beta = model.get_alpha(), model.get_beta()
    elif model.rts == RTS_CRS:
        # No intercept under RTS_CRS: one zero alpha per row of beta.
        beta = model.get_beta()
        alpha = np.zeros(beta.shape[0])
    else:
        # Previously this fell through and failed later with UnboundLocalError.
        raise ValueError(
            "Unsupported model.rts {!r}; expected RTS_VRS or RTS_CRS.".format(model.rts))

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.scatter(x[x_select_1], x[x_select_2], y, marker='.', s=10)

    # Revise the Z-axis to the left side by reordering its pane planes.
    # NOTE: relies on the private matplotlib attribute ``Axis._PLANES``.
    planes = ax.zaxis._PLANES
    ax.zaxis._PLANES = (planes[2], planes[3], planes[0],
                        planes[1], planes[4], planes[5])

    axes3 = (ax.xaxis, ax.yaxis, ax.zaxis)

    # NOTE(review): grid lines are hidden when line_transparent is False, the
    # opposite of what the name suggests; kept for backward compatibility.
    # Uses the private ``_axinfo`` dict of each 3d axis.
    if not line_transparent:
        for axis in axes3:
            axis._axinfo["grid"]['color'] = (1, 1, 1, 0)

    # make the panes transparent
    if pane_transparent:
        for axis in axes3:
            axis.set_pane_color((1.0, 1.0, 1.0, 0.0))

    # Evaluate the estimated function on a 30 x 30 grid over the observed range.
    xlin1 = np.linspace(x[x_select_1].min(), x[x_select_1].max(), 30)
    xlin2 = np.linspace(x[x_select_2].min(), x[x_select_2].max(), 30)
    XX, YY = np.meshgrid(xlin1, xlin2)
    ZZ = np.zeros(XX.shape)
    for i, j in np.ndindex(XX.shape):
        ZZ[i, j] = interpolation.interpolation(
            alpha, beta, x=np.array([XX[i, j], YY[i, j]], ndmin=2), fun=model.fun)

    ax.plot_surface(XX, YY, ZZ, rstride=1, cstride=1,
                    cmap='viridis', edgecolor='none', alpha=0.5)

    ax.set_xlabel("Input $x1$")
    ax.set_ylabel("Input $x2$")
    ax.set_zlabel("Output $y$", rotation=0)

    if fig_name is None:
        plt.show()
    else:
        plt.savefig(fig_name)