library('reticulate')
use_condaenv(condaenv = 'STOCK_MASTER', required = TRUE)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
def linex(y, yhat, alpha=0.5):
    """Return the LINEX (linear-exponential) loss of ``yhat`` against ``y``.

    Evaluated element-wise as ``exp(a * e) - a * e - 1`` where
    ``e = y - yhat`` and ``a = alpha * sign(y)``.  Because the slope ``a``
    is multiplied by ``sign(y)``, the side on which errors are penalised
    exponentially flips with the sign of the target, and the loss is
    identically 0 wherever ``y == 0``.  The result is never negative,
    since ``exp(x) - x - 1 >= 0`` for every real ``x``.

    Parameters
    ----------
    y : array_like
        Observed values.
    yhat : array_like
        Predicted values; must broadcast against ``y``.
    alpha : float, default 0.5
        Magnitude of the asymmetry parameter.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Element-wise loss with the broadcast shape of ``y - yhat``.
    """
    # asarray lets plain Python lists work as well as scalars and arrays
    y = np.asarray(y)
    yhat = np.asarray(yhat)
    error = y - yhat
    # Signed asymmetry parameter; computed once and reused in both terms.
    a = alpha * np.sign(y)
    return np.exp(a * error) - a * error - 1
Note that the values in `ax.view_init(elev=30, azim=-60)` are the defaults that set the viewing angle and rotation of a 3D plot:

- `elev`: elevation, the angle above/below the x-y plane
- `azim`: azimuth, the rotation about the z axis
# Data: a 30 x 30 grid of (y, yhat) pairs covering [-1.5, 1.5] in each direction
ys = np.linspace(-1.5, 1.5, 30)
yhats = ys
y, yhat = np.meshgrid(ys, yhats)
z = linex(y, yhat)
# Plot: LINEX loss surface over the (y, yhat) grid.
# Trailing semicolons are presumably there to suppress echoed return values
# when this runs as a notebook / R Markdown chunk.
ax = plt.axes(projection='3d')
ax.plot_surface(y, yhat, z, rstride=1, cstride=1, cmap='viridis', edgecolor='none');
ax.set_xlabel('y');
ax.set_ylabel('yhat');
ax.set_zlabel('loss');
ax.set_title('Linear exponential');
# Lower and rotate the camera relative to the defaults (elev=30, azim=-60).
ax.view_init(elev=10, azim=-15);
# Fix the z range to [0, 3]; the loss is never negative.
ax.set_zlim(0, 3); plt.show()
# Function
def plot_fun(gr=1, ax=None, **kwargs):
    """Plot the LINEX loss surface on a square grid and return the axes.

    Parameters
    ----------
    gr : float, default 1
        Half-width of the grid: ``y`` and ``yhat`` each span ``[-gr, gr]``
        with 30 evenly spaced points.
    ax : matplotlib 3D axes, optional
        Axes to draw on. They need a ``plot_surface`` method, i.e. they
        should be created with ``projection='3d'``. If omitted, new 3D
        axes are created with ``plt.axes(projection='3d')``.
    **kwargs
        Extra keyword arguments forwarded to ``ax.plot_surface``. They
        override the defaults ``rstride=1, cstride=1, cmap='viridis',
        edgecolor='none'``.

    Returns
    -------
    The axes the surface was drawn on.
    """
    # Use the axes the caller supplied; otherwise create fresh 3D axes.
    # plt.gca() is not used because it may hand back 2D axes, which have
    # no plot_surface method.
    if ax is None:
        ax = plt.axes(projection='3d')
    # Data: square (y, yhat) grid centred on the origin
    ys = np.linspace(-gr, gr, 30)
    yhats = ys
    y, yhat = np.meshgrid(ys, yhats)
    z = linex(y, yhat)
    # Caller-supplied kwargs take precedence over the original styling.
    surface_kwargs = {
        'rstride': 1,
        'cstride': 1,
        'cmap': 'viridis',
        'edgecolor': 'none',
        **kwargs,
    }
    ax.plot_surface(y, yhat, z, **surface_kwargs)
    ax.set_xlabel('y')
    ax.set_ylabel('yhat')
    ax.set_zlabel('loss')
    ax.set_title('Linear exponential')
    return ax
# Plot the LINEX surface on the narrower [-0.5, 0.5] grid
plot_fun(gr=0.5); plt.show()