from sklearn.datasets import *
X,y = make_blobs(n_samples=1000,
                 n_features=3,
                 centers=None,
                 cluster_std=1.0)
print(X.shape)
print(y.shape)
print(X[0])
(1000, 3)
(1000,)
[-1.87983687 -9.29421918 -7.51118103]
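As a quick sanity check (a sketch, not part of the original listing): with centers=None, make_blobs defaults to three isotropic Gaussian clusters, so the labels take the values 0, 1, and 2.
import numpy as np
print(np.unique(y))    # [0 1 2] -- three default centers
print(np.bincount(y))  # near-equal cluster sizes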
make_classification creates clusters of points normally distributed about the vertices of an n_informative-dimensional hypercube with sides of length 2*class_sep, and assigns an equal number of clusters to each class.
X,y = make_classification(n_samples=100,
                          n_features=20,
                          n_informative=2,
                          n_redundant=2,
                          n_repeated=0,
                          n_classes=2,
                          n_clusters_per_class=2,
                          weights=None,
                          flip_y=0.01,
                          class_sep=1.0,
                          hypercube=True,
                          shift=0.0,
                          scale=1.0,
                          shuffle=True,
                          random_state=None)
print(X.shape)
print(y.shape)
print(X[0])
(100, 20)
(100,)
[-0.5327797 -1.01850853 1.96627731 -1.37423585 0.28955284 0.60498103 -0.87262662 -1.2460967 -2.27388135 -0.67759805 -0.49377949 0.27307644 0.31466699 -1.40146697 -0.6476324 1.52066461 2.1903122 0.92581289 -1.04340091 0.00964815]
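A small follow-up sketch (my addition): weights=None yields balanced classes, up to the roughly 1% of labels randomly flipped by flip_y.
import numpy as np
print(np.bincount(y))  # two classes, close to 50/50 apart from flip_y noise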
X,y = make_gaussian_quantiles(mean=None,
                              cov=1.0,
                              n_samples=100,
                              n_features=2,
                              n_classes=3,
                              shuffle=True,
                              random_state=None)
print(X.shape)
print(y.shape)
print(X[0])
(100, 2)
(100,)
[-0.58473374 -0.76341671]
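make_gaussian_quantiles divides a single Gaussian cloud into classes by quantiles of the distance from the mean, so the classes form nested shells. A quick check (sketch):
import numpy as np
r = np.linalg.norm(X, axis=1)    # distance from the default mean at the origin
for c in range(3):
    print(c, r[y == c].mean())   # mean radius grows with the class index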
X,y = make_circles(n_samples=100,
                   shuffle=True,
                   noise=None,
                   random_state=None,
                   factor=0.8)
print(X.shape)
print(y.shape)
print(X[0])
(100, 2)
(100,)
[-0.96858316 0.24868989]
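factor is the scale of the inner circle relative to the outer one. With noise=None the samples lie exactly on the two circles, which a radius check confirms (sketch):
import numpy as np
r = np.linalg.norm(X, axis=1)
print(np.unique(np.round(r[y == 0], 6)))  # outer circle: radius 1.0
print(np.unique(np.round(r[y == 1], 6)))  # inner circle: radius 0.8 (= factor)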
X,y = make_moons(n_samples=100,
                 shuffle=True,
                 noise=None,
                 random_state=None)
print(X.shape)
print(y.shape)
print(X[0])
(100, 2)
(100,)
[ 0.23855404 -0.1482284 ]
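A minimal visual check (sketch, assuming matplotlib is available): the two interleaving half-circles produced by make_moons.
import matplotlib.pyplot as plt
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.coolwarm, s=15)
plt.title('make_moons, noise=None')
plt.show()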
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_multilabel_classification as make_ml_clf
COLORS = np.array(['!',          # index 0 is never used: label bit-patterns start at 1
                   '#FF3333',    # red
                   '#0198E1',    # blue
                   '#BF5FFF',    # purple
                   '#FCD116',    # yellow
                   '#FF7216',    # orange
                   '#4DBD33',    # green
                   '#87421F'])   # brown
# Use same random seed for multiple calls to make_multilabel_classification to
# ensure same distributions
RANDOM_SEED = np.random.randint(2 ** 10)
def plot_2d(ax, n_labels=1, n_classes=3, length=50):
    X, Y, p_c, p_w_c = make_ml_clf(n_samples=150,
                                   n_features=2,
                                   n_classes=n_classes,
                                   n_labels=n_labels,
                                   length=length,
                                   allow_unlabeled=False,
                                   return_distributions=True,
                                   random_state=RANDOM_SEED)
    ax.scatter(X[:, 0], X[:, 1],
               color=COLORS.take((Y * [1, 2, 4]).sum(axis=1)),
               marker='.')
    ax.scatter(p_w_c[0] * length,
               p_w_c[1] * length,
               marker='*', linewidth=.5, edgecolor='black',
               s=20 + 1500 * p_c ** 2,
               color=COLORS.take([1, 2, 4]))
    ax.set_xlabel('Feature 0 count')
    return p_c, p_w_c
_, (ax1, ax2) = plt.subplots(1, 2,
                             sharex='row',
                             sharey='row',
                             figsize=(8, 4))
plt.subplots_adjust(bottom=.15)
p_c, p_w_c = plot_2d(ax1, n_labels=1)
ax1.set_title('n_labels=1, length=50')
ax1.set_ylabel('Feature 1 count')
plot_2d(ax2, n_labels=3)
ax2.set_title('n_labels=3, length=50')
ax2.set_xlim(left=0, auto=True)
ax2.set_ylim(bottom=0, auto=True)
plt.show()
print('Class', 'P(C)', 'P(w0|C)', 'P(w1|C)', sep='\t')
for k, p, p_w in zip(['red', 'blue', 'yellow'], p_c, p_w_c.T):
    print('%s\t%0.2f\t%0.2f\t%0.2f' % (k, p, p_w[0], p_w[1]))
Class   P(C)  P(w0|C)  P(w1|C)
red     0.52  0.52     0.48
blue    0.33  0.78     0.22
yellow  0.15  0.48     0.52
X,y = make_hastie_10_2(n_samples=12000,
                       random_state=None)
print(X.shape)
print(y.shape)
print(X[0])
(12000, 10)
(12000,)
[ 2.00021898 0.52680462 1.26255452 0.39019033 -1.53801338 0.85746915 0.50570651 -1.51914679 0.76832722 0.15300295]
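The Hastie et al. target is deterministic: y = 1 when the squared norm of the ten standard-normal features exceeds 9.34, else -1 (per the make_hastie_10_2 docstring). A quick verification sketch:
import numpy as np
y_rule = np.where((X ** 2).sum(axis=1) > 9.34, 1.0, -1.0)
print((y_rule == y).mean())  # 1.0 -- the rule reproduces every label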
data,rows,cols = make_biclusters(shape=(300,300),n_clusters=4,noise=10)
print(data.shape)
print(rows.shape)
print(cols.shape)
print(data[0][0])
(300, 300)
(4, 300)
(4, 300)
-10.249833164684132
data,rows,cols = make_checkerboard(shape=(300,300),n_clusters=4,noise=10)
print(data.shape)
print(rows.shape)
print(cols.shape)
print(data[0][0])
(300, 300)
(16, 300)
(16, 300)
23.496944362976997
Note that make_checkerboard with n_clusters=4 produces 4 row clusters times 4 column clusters, so the indicator arrays describe 4 * 4 = 16 biclusters, hence the (16, 300) shapes.
X,y = make_regression()
print(X.shape,y.shape,X[0])
(100, 100) (100,) [ 0.06698029 -0.3069998 0.77184325 -0.39840182 -0.00452168 -0.11195994 -0.14763287 0.27955819 1.02062972 -0.81439297 -0.48074648 0.23396582 0.48454019 -1.77366754 0.37189498 -0.53646237 -0.84980731 1.39237433 0.25815746 1.63455151 -1.17181136 0.33051073 1.1862697 0.73710681 -0.31791374 0.49778753 0.27869739 1.89215448 0.04803009 -0.12867303 0.21967041 -0.28971271 -0.54499742 0.79278887 0.90996164 -1.04256368 0.25586554 0.7123686 1.07949337 0.59558288 0.14479018 1.42451383 0.27289982 -0.66993241 -0.38717179 -0.36648667 -0.19179518 -0.28000574 0.34400883 -0.16284098 0.67861264 -1.550955 0.33024865 -0.16968446 -0.49826749 -1.8160245 -2.75791505 -0.32876184 -0.13179621 -1.89203641 -0.57492444 1.54479834 -0.38214558 -0.04896023 -0.23173704 -0.71714912 -0.71399436 -0.01003642 -0.50113651 -1.34137456 1.75291892 -1.01586596 -0.83445588 -0.62066657 -0.35601039 1.40640581 -0.31349628 1.30424865 0.545493 -0.16099864 0.15380927 -1.34948588 0.84204382 1.17065653 -1.15027473 -0.18641097 0.42833971 -0.21224998 2.07090812 1.33943627 1.00057381 0.18484586 0.91062904 -0.53892385 1.44622403 -0.29545586 0.15016603 -1.47529304 -0.36163744 0.88709162]
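make_regression() with all defaults draws 100 features, but only n_informative=10 of them drive the target. Passing coef=True returns the ground-truth linear coefficients, which makes that sparsity visible (sketch; Xr, yr are my names, not from the original):
import numpy as np
Xr, yr, coef = make_regression(coef=True)
print((coef != 0).sum())  # 10 non-zero coefficients out of 100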
X,y = make_sparse_uncorrelated(n_samples=100,n_features=10,random_state=None)
print(X.shape)
print(y.shape)
print(X[0])
(100, 10)
(100,)
[-0.5689727 0.5576709 -1.70289727 1.02627569 0.4254925 -0.71641036 -1.54054783 0.54404947 -1.63913706 0.82025732]
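Per the make_sparse_uncorrelated docstring, only the first four features enter the target: y ~ N(X[:,0] + 2*X[:,1] - 2*X[:,2] - 1.5*X[:,3], 1). A correlation check (sketch):
import numpy as np
pred = X[:, 0] + 2 * X[:, 1] - 2 * X[:, 2] - 1.5 * X[:, 3]
print(np.corrcoef(pred, y)[0, 1])  # high, around 0.95; the rest is unit noise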
For make_friedman1, n_features must be >= 5: only the first five features are used to compute $y$, and all other features are independent of it.
X,y = make_friedman1(n_samples=100,n_features=10,noise=1.0,random_state=None)
print(X.shape)
print(y.shape)
print(X[0])
(100, 10)
(100,)
[0.11953237 0.11602222 0.59883894 0.30855094 0.39853371 0.10051568 0.02995793 0.3369666 0.91208936 0.34997684]
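The Friedman #1 target (from the docstring) is y = 10*sin(pi*x0*x1) + 20*(x2 - 0.5)**2 + 10*x3 + 5*x4 + noise*N(0, 1), so the residual against the noiseless formula should have a standard deviation near the noise level (sketch):
import numpy as np
signal = (10 * np.sin(np.pi * X[:, 0] * X[:, 1])
          + 20 * (X[:, 2] - 0.5) ** 2 + 10 * X[:, 3] + 5 * X[:, 4])
print(np.std(y - signal))  # approximately 1.0, the noise level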
X,y = make_friedman2(n_samples=100,noise=1.0,random_state=None)
print(X.shape)
print(y.shape)
print(X[0])
(100, 4)
(100,)
[ 52.3916144 158.46614298 0.80609179 4.1962008 ]
X,y = make_friedman3(n_samples=100,noise=1.0,random_state=None)
print(X.shape)
print(y.shape)
print(X[0])
(100, 4)
(100,)
[2.50694239e-01 5.90670345e+02 4.34260321e-01 4.41140228e+00]
n_samples = 300
X, color = make_s_curve(n_samples, random_state=0)
print(X.shape)
print(color.shape)
print(X[0])
(300, 3)
(300,)
[ 0.44399868 1.813111 -0.10397256]
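The returned color array encodes each sample's position along the main dimension of the manifold, which is handy for judging manifold-learning embeddings visually. A plotting sketch (assuming matplotlib >= 3.2):
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=color, cmap=plt.cm.Spectral)
plt.show()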
import time
import numpy as np
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.axes3d as p3
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import make_swiss_roll
# Generate data (swiss roll dataset)
n_samples = 1500
noise = 0.05
X, _ = make_swiss_roll(n_samples, noise=noise)
st = time.time()
ward = AgglomerativeClustering(n_clusters=6, linkage='ward').fit(X)
elapsed_time = time.time() - st
label = ward.labels_
print("Elapsed time: %.2fs" % elapsed_time)
print("Number of points: %i" % label.size)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')  # Axes3D(fig) no longer auto-attaches in recent matplotlib
ax.view_init(5, -80)
for l in np.unique(label):
    ax.scatter(X[label == l, 0], X[label == l, 1], X[label == l, 2],
               color=plt.cm.jet(float(l) / np.max(label + 1)),
               s=20, edgecolor='k')
plt.title('Without connectivity constraints (time %.2fs)' % elapsed_time)
Elapsed time: 0.04s
Number of points: 1500
Text(0.5, 0.92, 'Without connectivity constraints (time 0.04s)')
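For comparison with the "without connectivity constraints" run above, here is a hedged sketch of the constrained variant (mirroring the companion scikit-learn example): a k-nearest-neighbors graph restricts Ward merges to local neighbors, so clusters follow the roll instead of cutting across it.
import numpy as np
from sklearn.neighbors import kneighbors_graph
connectivity = kneighbors_graph(X, n_neighbors=10, include_self=False)
ward_c = AgglomerativeClustering(n_clusters=6, linkage='ward',
                                 connectivity=connectivity).fit(X)
print(np.bincount(ward_c.labels_))  # cluster sizes under the constraint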
effective_rank is the approximate number of singular vectors required to explain most of the data by linear combinations; tail_strength sets the relative weight of the fat, noisy tail of the singular-value profile.
lrk = make_low_rank_matrix(n_samples=100,
                           n_features=100,
                           effective_rank=10,
                           tail_strength=0.5)
print(lrk.shape)
print(lrk[0])
(100, 100)
[ 0.0207483 0.03501036 -0.00856394 -0.03304713 0.02422049 ... -0.01823487 0.04917044 0.06563823]
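A spectrum check (sketch): the singular values should decay sharply past effective_rank=10, with a residual tail whose weight is set by tail_strength.
import numpy as np
s = np.linalg.svd(lrk, compute_uv=False)
print(s[:12])  # most of the spectral mass sits in the first ~10 values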
make_sparse_coded_signal encodes a signal as a sparse combination of dictionary elements: y = Xw, where w is the sparse code (with n_nonzero_coefs non-zero items), an ndarray of shape (#components, #samples).
y,X,w = make_sparse_coded_signal(n_samples=1,
                                 n_components=512,
                                 n_features=100,
                                 n_nonzero_coefs=17)
print(y.shape)
print(X.shape)
print(w.shape)
print(X[0])
(100,)
(100, 512)
(512,)
[-1.12152746e-01 -1.74929245e-01 -4.00342334e-02 1.24385444e-01 -8.58424988e-02 ... -1.95228079e-01 6.10785709e-02]
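The generator guarantees that the signal is an exactly sparse combination of dictionary atoms: y = X @ w with n_nonzero_coefs non-zeros in w (with n_samples=1 the arrays come back squeezed to 1-D, as the shapes above show). A verification sketch:
import numpy as np
print(np.count_nonzero(w))    # 17, as requested
print(np.allclose(y, X @ w))  # True: the signal is exactly X times the code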
spd2x2 = make_spd_matrix(2)
spd3x3 = make_spd_matrix(3)
print(spd2x2,"\n",spd3x3)
[[2.4129143  0.12343756]
 [0.12343756 0.40660248]] 
 [[ 3.40474192 -0.70747451  0.32173119]
 [-0.70747451  0.63463772 -0.10783239]
 [ 0.32173119 -0.10783239  0.77290192]]
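A quick SPD check (sketch): a symmetric positive-definite matrix equals its own transpose and has strictly positive eigenvalues.
import numpy as np
print(np.allclose(spd3x3, spd3x3.T))   # symmetric
print(np.linalg.eigvalsh(spd3x3) > 0)  # all eigenvalues positive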
Use Graphical Lasso to learn covariance and sparse precision from a small number of samples.
To estimate a probabilistic (e.g., Gaussian) model, estimating the precision (inverse covariance) matrix is as important as estimating the covariance matrix: a Gaussian model is parametrized by its precision matrix.
To be in favorable recovery conditions, we sample the data from a model with a sparse inverse covariance matrix. In addition, we ensure that the data is not too correlated (limiting the largest coefficient of the precision matrix) and that there are no small coefficients in the precision matrix that cannot be recovered.
The number of samples is slightly larger than the number of dimensions, so the empirical covariance is still invertible. However, since the observations are strongly correlated, the empirical covariance matrix is ill-conditioned, and its inverse (the empirical precision matrix) is very far from the ground truth.
If we use l2 shrinkage, as with the Ledoit-Wolf estimator, the small number of samples forces us to shrink a lot. The Ledoit-Wolf precision is therefore fairly close to the ground truth precision, which is not far from being diagonal, but the off-diagonal structure is lost.
The l1-penalized estimator can recover part of this off-diagonal structure. It learns a sparse precision. It cannot recover the exact sparsity pattern (it detects too many non-zero coefficients), but the highest non-zero coefficients of the l1-estimated precision do correspond to non-zero coefficients in the ground truth.
The coefficients of the l1 precision estimate are biased toward zero: because of the penalty, they are all smaller than the corresponding ground-truth values, as can be seen in the figure.
The color range of the precision matrices is tweaked to improve readability of the figure. The full range of values of the empirical precision is not displayed.
The GraphicalLasso alpha (sparsity) parameter is set by internal cross-validation (GraphicalLassoCV).
import numpy as np
from scipy import linalg
from sklearn.datasets import make_sparse_spd_matrix
from sklearn.covariance import GraphicalLassoCV, ledoit_wolf
import matplotlib.pyplot as plt
n_samples, n_features = 60,20
prng = np.random.RandomState(1)
prec = make_sparse_spd_matrix(n_features, alpha=.98,
                              smallest_coef=.4,
                              largest_coef=.7,
                              random_state=prng)
cov = linalg.inv(prec)
d = np.sqrt(np.diag(cov))
cov /= d
cov /= d[:, np.newaxis]
prec *= d
prec *= d[:, np.newaxis]
X = prng.multivariate_normal(np.zeros(n_features), cov, size=n_samples)
X -= X.mean(axis=0)
X /= X.std(axis=0)
# Estimate the covariance
emp_cov = np.dot(X.T, X) / n_samples
model = GraphicalLassoCV().fit(X)
cov_ = model.covariance_
prec_ = model.precision_
lw_cov_, _ = ledoit_wolf(X)
lw_prec_ = linalg.inv(lw_cov_)
plt.figure(figsize=(10, 6))
plt.subplots_adjust(left=0.02, right=0.98)
# plot the covariances
covs = [('Empirical', emp_cov),
        ('Ledoit-Wolf', lw_cov_),
        ('GraphicalLassoCV', cov_),
        ('True', cov)]
vmax = cov_.max()
for i, (name, this_cov) in enumerate(covs):
    plt.subplot(2, 4, i + 1)
    plt.imshow(this_cov, interpolation='nearest', vmin=-vmax, vmax=vmax,
               cmap=plt.cm.RdBu_r)
    plt.xticks(())
    plt.yticks(())
    plt.title('%s covariance' % name)
# plot the precisions
precs = [('Empirical', linalg.inv(emp_cov)),
         ('Ledoit-Wolf', lw_prec_),
         ('GraphicalLasso', prec_),
         ('True', prec)]
vmax = .9 * prec_.max()
for i, (name, this_prec) in enumerate(precs):
    ax = plt.subplot(2, 4, i + 5)
    plt.imshow(np.ma.masked_equal(this_prec, 0),
               interpolation='nearest', vmin=-vmax, vmax=vmax,
               cmap=plt.cm.RdBu_r)
    plt.xticks(())
    plt.yticks(())
    plt.title('%s precision' % name)
    if hasattr(ax, 'set_facecolor'):
        ax.set_facecolor('.7')
    else:
        ax.set_axis_bgcolor('.7')
# plot the model selection metric
plt.figure(figsize=(4, 3))
plt.axes([.2, .15, .75, .7])
plt.plot(model.cv_results_["alphas"], model.cv_results_["mean_score"], 'o-')
plt.axvline(model.alpha_, color='.5')
plt.title('Model selection')
plt.ylabel('Cross-validation score')
plt.xlabel('alpha')
Text(0.5, 0, 'alpha')