# Source code for pmlearn.mixture.tests.test_dirichlet_process

import unittest
import shutil
import tempfile

import numpy as np
# import pandas as pd
# import pymc3 as pm
# from pymc3 import summary
# from sklearn.mixture import BayesianGaussianMixture as skBayesianGaussianMixture
from sklearn.model_selection import train_test_split

from pmlearn.exceptions import NotFittedError
from pmlearn.mixture import DirichletProcessMixture


class DirichletProcessMixtureTestCase(unittest.TestCase):
    """Shared fixture for the ``DirichletProcessMixture`` test cases.

    ``setUp`` draws a small one-dimensional synthetic data set, splits it
    into train/test parts, creates two unfitted models and a scratch
    directory; ``tearDown`` removes that directory again.
    """

    def setUp(self):
        """Build the synthetic data, the models and a temporary directory."""
        self.num_truncate = 3
        self.num_components = 3
        self.num_pred = 1
        self.num_training_samples = 100

        self.pi = np.array([0.35, 0.4, 0.25])
        self.means = np.array([0, 5, 10])
        self.sigmas = np.array([0.5, 0.5, 1.0])

        # NOTE(review): component labels are drawn uniformly with randint;
        # self.pi is stored but not used for sampling -- confirm intended.
        self.components = np.random.randint(0,
                                            self.num_components,
                                            self.num_training_samples)

        # One normal draw per sample, using the mean/sigma selected by the
        # sample's component label, then reshaped in place to a column.
        X = np.random.normal(loc=self.means[self.components],
                             scale=self.sigmas[self.components])
        X.shape = (self.num_training_samples, 1)

        self.X_train, self.X_test = train_test_split(X, test_size=0.3)

        self.test_DPMM = DirichletProcessMixture()
        self.test_nuts_DPMM = DirichletProcessMixture()
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        """Delete the temporary directory created by ``setUp``."""
        shutil.rmtree(self.test_dir)
# class DirichletProcessMixtureFitTestCase(DirichletProcessMixtureTestCase): # def test_advi_fit_returns_correct_model(self): # # This print statement ensures PyMC3 output won't overwrite the test name # print('') # self.test_DPMM.fit(self.X_train) # # self.assertEqual(self.num_pred, self.test_DPMM.num_pred) # self.assertEqual(self.num_components, self.test_DPMM.num_components) # self.assertEqual(self.num_truncate, self.test_DPMM.num_truncate) # # self.assertAlmostEqual(self.pi[0], # self.test_DPMM.summary['mean']['pi__0'], # 0) # self.assertAlmostEqual(self.pi[1], # self.test_DPMM.summary['mean']['pi__1'], # 0) # self.assertAlmostEqual(self.pi[2], # self.test_DPMM.summary['mean']['pi__2'], # 0) # # self.assertAlmostEqual( # self.means[0], # self.test_DPMM.summary['mean']['cluster_center_0__0'], # 0) # self.assertAlmostEqual( # self.means[1], # self.test_DPMM.summary['mean']['cluster_center_1__0'], # 0) # self.assertAlmostEqual( # self.means[2], # self.test_DPMM.summary['mean']['cluster_center_2__0'], # 0) # # self.assertAlmostEqual( # self.sigmas[0], # self.test_DPMM.summary['mean']['cluster_variance_0__0'], # 0) # self.assertAlmostEqual( # self.sigmas[1], # self.test_DPMM.summary['mean']['cluster_variance_1__0'], # 0) # self.assertAlmostEqual( # self.sigmas[2], # self.test_DPMM.summary['mean']['cluster_variance_2__0'], # 0) # # def test_nuts_fit_returns_correct_model(self): # # This print statement ensures PyMC3 output won't overwrite the test name # print('') # self.test_nuts_DPMM.fit(self.X_train, # inference_type='nuts', # inference_args={'draws': 1000, # 'chains': 2}) # # self.assertEqual(self.num_pred, self.test_nuts_DPMM.num_pred) # self.assertEqual(self.num_components, self.test_nuts_DPMM.num_components) # self.assertEqual(self.num_components, self.test_nuts_DPMM.num_truncate) # # self.assertAlmostEqual(self.pi[0], # self.test_nuts_DPMM.summary['mean']['pi__0'], # 0) # self.assertAlmostEqual(self.pi[1], # self.test_nuts_DPMM.summary['mean']['pi__1'], # 0) 
# self.assertAlmostEqual(self.pi[2], # self.test_nuts_DPMM.summary['mean']['pi__2'], # 0) # # self.assertAlmostEqual( # self.means[0], # self.test_nuts_DPMM.summary['mean']['cluster_center_0__0'], # 0) # self.assertAlmostEqual( # self.means[1], # self.test_nuts_DPMM.summary['mean']['cluster_center_1__0'], # 0) # self.assertAlmostEqual( # self.means[2], # self.test_nuts_DPMM.summary['mean']['cluster_center_2__0'], # 0) # # self.assertAlmostEqual( # self.sigmas[0], # self.test_nuts_DPMM.summary['mean']['cluster_variance_0__0'], # 0) # self.assertAlmostEqual( # self.sigmas[1], # self.test_nuts_DPMM.summary['mean']['cluster_variance_1__0'], # 0) # self.assertAlmostEqual( # self.sigmas[2], # self.test_nuts_DPMM.summary['mean']['cluster_variance_2__0'], # 0) # #
class DirichletProcessMixturePredictTestCase(DirichletProcessMixtureTestCase):
    """Checks for ``DirichletProcessMixture.predict``."""

    # NOTE(review): the disabled tests below use self.y_train / self.y_test,
    # which the setUp shown in this file does not define -- confirm before
    # re-enabling them.
    #
    # def test_predict_returns_predictions(self):
    #     print('')
    #     self.test_DPMM.fit(self.X_train, self.y_train)
    #     preds = self.test_DPMM.predict(self.X_test)
    #     self.assertEqual(self.y_test.shape, preds.shape)

    # def test_predict_returns_mean_predictions_and_std(self):
    #     print('')
    #     self.test_DPMM.fit(self.X_train, self.y_train)
    #     preds, stds = self.test_DPMM.predict(self.X_test, return_std=True)
    #     self.assertEqual(self.y_test.shape, preds.shape)
    #     self.assertEqual(self.y_test.shape, stds.shape)

    def test_predict_raises_error_if_not_fit(self):
        """``predict`` on an unfitted model raises ``NotFittedError``."""
        print('')
        # Construct the model outside the assertRaises context so that only
        # the predict() call itself can satisfy the expectation.
        unfitted_model = DirichletProcessMixture()
        with self.assertRaises(NotFittedError) as no_fit_error:
            unfitted_model.predict(self.X_train)

        expected = 'Run fit on the model before predict.'
        self.assertEqual(str(no_fit_error.exception), expected)
# class DirichletProcessMixtureScoreTestCase(DirichletProcessMixtureTestCase): # def test_score_matches_sklearn_performance(self): # print('') # skDPMM = skBayesianGaussianMixture(n_components=3) # skDPMM.fit(self.X_train) # skDPMM_score = skDPMM.score(self.X_test) # # self.test_DPMM.fit(self.X_train) # test_DPMM_score = self.test_DPMM.score(self.X_test) # # self.assertAlmostEqual(skDPMM_score, test_DPMM_score, 0) # # # class DirichletProcessMixtureSaveAndLoadTestCase(DirichletProcessMixtureTestCase): # def test_save_and_load_work_correctly(self): # print('') # self.test_DPMM.fit(self.X_train) # score1 = self.test_DPMM.score(self.X_test) # self.test_DPMM.save(self.test_dir) # # DPMM2 = DirichletProcessMixture() # DPMM2.load(self.test_dir) # # self.assertEqual(self.test_DPMM.inference_type, DPMM2.inference_type) # self.assertEqual(self.test_DPMM.num_pred, DPMM2.num_pred) # self.assertEqual(self.test_DPMM.num_training_samples, # DPMM2.num_training_samples) # self.assertEqual(self.test_DPMM.num_truncate, DPMM2.num_truncate) # # pd.testing.assert_frame_equal(summary(self.test_DPMM.trace), # summary(DPMM2.trace)) # # score2 = DPMM2.score(self.X_test) # self.assertAlmostEqual(score1, score2, 0)