50 using namespace internal;
52 using std::unique_ptr;
58 void init_statistic_job();
59 void init_permutation_job();
60 void init_variance_h1_job();
99 static constexpr
bool DEFAULT_PRECOMPUTE =
true;
100 static constexpr
index_t DEFAULT_NUM_EIGENVALUES = 10;
112 REQUIRE(owner.get_num_samples_p()>0,
113 "Number of samples from P (was %s) has to be > 0!\n", owner.get_num_samples_p());
114 REQUIRE(owner.get_num_samples_q()>0,
115 "Number of samples from Q (was %s) has to be > 0!\n", owner.get_num_samples_q());
117 statistic_job.m_n_x=owner.get_num_samples_p();
118 statistic_job.m_n_y=owner.get_num_samples_q();
119 statistic_job.m_stype=owner.get_statistic_type();
124 REQUIRE(owner.get_num_samples_p()>0,
125 "Number of samples from P (was %s) has to be > 0!\n", owner.get_num_samples_p());
126 REQUIRE(owner.get_num_samples_q()>0,
127 "Number of samples from Q (was %s) has to be > 0!\n", owner.get_num_samples_q());
129 variance_h1_job.m_n_x=owner.get_num_samples_p();
130 variance_h1_job.m_n_y=owner.get_num_samples_q();
135 REQUIRE(owner.get_num_samples_p()>0,
136 "Number of samples from P (was %s) has to be > 0!\n", owner.get_num_samples_p());
137 REQUIRE(owner.get_num_samples_q()>0,
138 "Number of samples from Q (was %s) has to be > 0!\n", owner.get_num_samples_q());
139 REQUIRE(owner.get_num_null_samples()>0,
140 "Number of null samples (was %d) has to be > 0!\n", owner.get_num_null_samples());
142 permutation_job.m_n_x=owner.get_num_samples_p();
143 permutation_job.m_n_y=owner.get_num_samples_q();
144 permutation_job.m_stype=owner.get_statistic_type();
145 permutation_job.m_num_null_samples=owner.get_num_null_samples();
150 ASSERT(owner.get_kernel());
151 if (!is_kernel_initialized)
154 auto samples_p_and_q=owner.get_p_and_q();
156 auto kernel=owner.get_kernel();
157 kernel->init(samples_p_and_q, samples_p_and_q);
158 is_kernel_initialized=
true;
159 SG_SINFO(
"Kernel is initialized with joint features of %d total samples!\n", samples_p_and_q->get_num_vectors());
166 ASSERT(owner.get_kernel());
167 ASSERT(is_kernel_initialized);
169 if (owner.get_kernel()->get_kernel_type()!=
K_CUSTOM)
171 auto kernel=owner.get_kernel();
172 owner.get_kernel_mgr().precompute_kernel_at(0);
173 kernel->remove_lhs_and_rhs();
177 auto precomputed_kernel=
static_cast<CCustomKernel*
>(owner.get_kernel());
181 CQuadraticTimeMMD::CQuadraticTimeMMD() :
CMMD()
191 void CQuadraticTimeMMD::init()
193 self=unique_ptr<Self>(
new Self(*
this));
204 if (samples_from_p!=
get_p())
208 self->is_kernel_initialized=
false;
209 self->multi_kernel->invalidate_precomputed_distance();
213 SG_WARNING(
"Existing kernel is already precomputed. Features provided will be\ 214 ignored unless the kernel is updated with a non-precomputed one!\n");
215 self->is_kernel_initialized=
true;
220 SG_INFO(
"Provided features are the same as the existing one. Ignoring!\n");
226 if (samples_from_q!=
get_q())
230 self->is_kernel_initialized=
false;
231 self->multi_kernel->invalidate_precomputed_distance();
235 SG_WARNING(
"Existing kernel is already precomputed. Features provided will be\ 236 ignored unless the kernel is updated with a non-precomputed one!\n");
237 self->is_kernel_initialized=
true;
242 SG_INFO(
"Provided features are the same as the existing one. Ignoring!\n");
254 auto samples=data_mgr.next();
255 if (!samples.empty())
267 return samples_p_and_q;
276 self->is_kernel_initialized=
false;
280 SG_INFO(
"Setting a precomputed kernel. Features provided will be ignored!\n");
281 self->is_kernel_initialized=
true;
286 SG_INFO(
"Provided kernel is the same as the existing one. Ignoring!\n");
293 self->is_kernel_initialized=
false;
298 SG_WARNING(
"Selected kernel is already precomputed. Features provided will be\ 299 ignored unless the kernel is updated with a non-precomputed one!\n");
300 self->is_kernel_initialized=
true;
308 return Nx*Ny*statistic/(Nx+Ny);
317 self->init_statistic_job();
321 if (self->precompute)
324 statistic=
self->statistic_job(kernel_matrix);
329 if (kernel->get_kernel_type()==
K_CUSTOM)
330 SG_INFO(
"Precompute is turned off, but provided kernel is already precomputed!\n");
331 auto kernel_functor=internal::Kernel(kernel);
332 statistic=
self->statistic_job(kernel_functor);
358 if (kernel->get_kernel_type()==
K_CUSTOM)
359 SG_SINFO(
"Precompute is turned off, but provided kernel is already precomputed!\n");
360 auto kernel_functor=internal::Kernel(kernel);
365 for (
auto i=0; i<result.
vlen; ++i)
375 REQUIRE(owner.get_kernel(),
"Kernel is not set!\n");
376 REQUIRE(precompute,
"MMD2_SPECTRUM is not possible without precomputing the kernel matrix!\n");
378 index_t m=owner.get_num_samples_p();
379 index_t n=owner.get_num_samples_q();
381 REQUIRE(num_eigenvalues>0 && num_eigenvalues<m+n-1,
382 "Number of Eigenvalues (%d) must be in between [1, %d]\n", num_eigenvalues, m+n-1);
391 std::copy(kernel_matrix.
data(), kernel_matrix.
data()+kernel_matrix.
size(), K.
data());
398 Eigen::SelfAdjointEigenSolver<Eigen::MatrixXf> eigen_solver(c_kernel_matrix);
399 REQUIRE(eigen_solver.info()==Eigen::Success,
"Eigendecomposition failed!\n");
400 index_t max_num_eigenvalues=eigen_solver.eigenvalues().rows();
405 for (
auto i=0; i<null_samples.vlen; ++i)
408 for (
index_t j=0; j<num_eigenvalues; ++j)
414 float64_t eigenvalue_estimate=eigen_solver.eigenvalues()[max_num_eigenvalues-1-j];
415 eigenvalue_estimate/=(m+n);
420 null_sample+=eigenvalue_estimate*multiple;
422 null_samples[i]=null_sample;
433 REQUIRE(owner.get_kernel(),
"Kernel is not set!\n");
434 REQUIRE(precompute,
"MMD2_GAMMA is not possible without precomputing the kernel matrix!\n");
437 index_t m=owner.get_num_samples_p();
438 index_t n=owner.get_num_samples_q();
439 REQUIRE(m==n,
"Number of samples from p (%d) and q (%d) must be equal.\n", n, m)
460 mean_mmd+=kernel_matrix(i, m+i);
462 mean_mmd=2.0/m*(1.0-1.0/m*mean_mmd);
475 if (i==j || m+i==j || m+j==i)
479 to_add+=kernel_matrix(m+i, m+j);
480 to_add-=kernel_matrix(i, m+j);
481 to_add-=kernel_matrix(m+i, j);
482 var_mmd+=CMath::pow(to_add, 2);
486 var_mmd*=2.0/m/(m-1)*1.0/m/(m-1);
489 float64_t a=CMath::pow(mean_mmd, 2)/var_mmd;
501 REQUIRE(get_kernel(),
"Kernel is not set!\n");
503 "Computing variance estimate is not possible without precomputing the kernel matrix!\n");
507 return self->variance_h0_job(kernel_matrix);
512 REQUIRE(get_kernel(),
"Kernel is not set!\n");
514 self->init_variance_h1_job();
516 if (self->precompute)
519 variance_estimate=
self->variance_h1_job(kernel_matrix);
523 auto kernel=get_kernel();
524 if (kernel->get_kernel_type()==
K_CUSTOM)
525 SG_INFO(
"Precompute is turned off, but provided kernel is already precomputed!\n");
526 auto kernel_functor=internal::Kernel(kernel);
527 variance_estimate=
self->variance_h1_job(kernel_functor);
529 return variance_estimate;
534 REQUIRE(get_kernel(),
"Kernel is not set!\n");
536 switch (get_null_approximation_method())
541 result=CStatistics::gamma_cdf(statistic, params[0], params[1]);
545 result=CHypothesisTest::compute_p_value(statistic);
553 REQUIRE(get_kernel(),
"Kernel is not set!\n");
555 switch (get_null_approximation_method())
560 result=CStatistics::gamma_inverse_cdf(alpha, params[0], params[1]);
564 result=CHypothesisTest::compute_threshold(alpha);
572 REQUIRE(get_kernel(),
"Kernel is not set!\n");
574 switch (get_null_approximation_method())
577 null_samples=
self->sample_null_spectrum();
580 null_samples=
self->sample_null_permutation();
589 return self->multi_kernel.
get();
592 void CQuadraticTimeMMD::spectrum_set_num_eigenvalues(
index_t num_eigenvalues)
594 self->num_eigenvalues=num_eigenvalues;
597 index_t CQuadraticTimeMMD::spectrum_get_num_eigenvalues()
const 599 return self->num_eigenvalues;
602 void CQuadraticTimeMMD::precompute_kernel_matrix(
bool precompute)
604 if (self->precompute && !precompute)
608 get_kernel_mgr().restore_kernel_at(0);
609 self->is_kernel_initialized=
false;
610 if (get_kernel()->get_kernel_type()==
K_CUSTOM)
612 SG_WARNING(
"The existing kernel itself is a precomputed kernel!\n");
614 self->is_kernel_initialized=
true;
618 self->precompute=precompute;
621 void CQuadraticTimeMMD::save_permutation_inds(
bool save_inds)
623 self->permutation_job.m_save_inds=save_inds;
628 return self->permutation_job.m_all_inds;
631 const char* CQuadraticTimeMMD::get_name()
const 633 return "QuadraticTimeMMD";
virtual ~CQuadraticTimeMMD()
The Custom Kernel allows for custom user provided kernel matrices.
void init_permutation_job()
virtual float64_t compute_statistic()
SGMatrix< float32_t > get_float32_kernel_matrix()
virtual void select_kernel()
const index_t get_num_samples_q() const
virtual float64_t normalize_statistic(float64_t statistic) const
static constexpr index_t DEFAULT_NUM_EIGENVALUES
virtual void set_p(CFeatures *samples_from_p)
CFeatures * get_p() const
SGVector< float64_t > sample_null_spectrum()
virtual void set_q(CFeatures *samples_from_q)
virtual CFeatures * create_merged_copy(CList *others)
unique_ptr< CMultiKernelQuadraticTimeMMD > multi_kernel
This class implements the quadratic time Maximum Mean Discrepancy (MMD) statistic as described in [1]. The MMD is the distance between two probability distributions $p$ and $q$ in an RKHS $\mathcal{F}$, which we denote by $\mathrm{MMD}[\mathcal{F}, p, q]$.
virtual void set_p(CFeatures *samples_from_p)
bool is_kernel_initialized
SGVector< float64_t > gamma_fit_null()
friend class CMultiKernelQuadraticTimeMMD
CFeatures * get_q() const
VarianceH0 variance_h0_job
void init_variance_h1_job()
PermutationMMD permutation_job
void init_statistic_job()
internal::DataManager & get_data_mgr()
VarianceH1 variance_h1_job
const index_t get_num_samples_p() const
Class that performs quadratic time MMD test optimized for multiple shift-invariant kernels...
internal::KernelManager & get_kernel_mgr()
All classes and functions are contained in the shogun namespace.
virtual EKernelType get_kernel_type()=0
T get(const Tag< T > &_tag) const
CQuadraticTimeMMD & owner
The class Features is the base class of all feature objects.
CFeatures * get_p_and_q()
CKernel * get_kernel() const
Class DataManager for fetching/streaming test data block-wise. It can handle data coming from multipl...
Abstract base class that provides an interface for performing kernel two-sample test using Maximum Me...
SGMatrix< float32_t > get_kernel_matrix()
virtual void set_kernel(CKernel *kernel)
SGVector< float64_t > sample_null_permutation()
virtual void set_kernel(CKernel *kernel)
Self(CQuadraticTimeMMD &)
virtual void select_kernel()
static constexpr bool DEFAULT_PRECOMPUTE
virtual void set_q(CFeatures *samples_from_q)