25 #include <queso/InfoTheory.h>
27 #include <queso/Defines.h>
28 #include <gsl/gsl_sf_psi.h>
// Fragment of the nearest-neighbour distance routine. Its name and leading
// parameters are not visible in this excerpt (presumably distANN_XY, judging
// by the nine-argument calls elsewhere in this listing -- TODO confirm).
// Visible trailing parameters: both dimensions, both sample counts, k, eps.
39 unsigned int dimX,
unsigned int dimY,
40 unsigned int xN,
unsigned int yN,
41 unsigned int k,
double eps )
// One query per row of dataX; the result for row i ends up in distsXY[ i ].
54 for(
unsigned int i = 0; i < xN ; i++ )
// Request k+1 neighbours of the query point, forwarding eps to the search.
// (How kdTree is built is not visible in this excerpt.)
56 kdTree->
annkSearch( dataX[ i ], k+1, nnIdx, nnDist, eps );
// Use the entry at index k, i.e. the last of the k+1 returned distances.
58 double my_dist = nnDist[
k ];
// Second search requesting yN neighbours into the *_tmp buffers. The
// condition guarding it is not visible here -- presumably my_dist being
// zero; TODO confirm.
66 kdTree->
annkSearch( dataX[ i ], yN, nnIdx_tmp, nnDist_tmp, eps );
// Scan the entries after index k and adopt a strictly positive distance.
// Whether the scan stops at the first hit is not visible in this excerpt.
68 for(
unsigned int my_k = k + 1; my_k < yN; ++my_k )
69 if( nnDist_tmp[ my_k ] > 0.0 )
71 my_dist = nnDist_tmp[ my_k ];
// Publish the distance chosen for query point i.
78 distsXY[ i ] = my_dist;
// Standardises every column of dataXY in place (subtract the column mean,
// divide by the column standard deviation) and then copies the first dimX
// columns into dataX and the following dimY columns into dataY.
// Only the first two parameters are visible in this excerpt; dataX, dimX,
// dataY, dimY and N are used below, so they are presumably the remaining
// parameters -- TODO confirm against the full signature.
95 void normalizeANN_XY(
ANNpointArray dataXY,
unsigned int dimXY,
// Pass 1: accumulate per-column sums over all N rows. The initialisation of
// meanXY is not visible here (assumed zeroed -- TODO confirm).
107 for(
unsigned int i = 0; i < N; i++ ) {
// X columns occupy indices [0, dimX).
108 for(
unsigned int j = 0; j < dimX; j++ ) {
109 meanXY[ j ] += dataXY[ i ][ j ];
// Y columns occupy indices [dimX, dimX + dimY).
111 for(
unsigned int j = 0; j < dimY; j++ ) {
112 meanXY[ dimX + j ] += dataXY[ i ][ dimX + j ];
// Turn the sums into means.
115 for(
unsigned int j = 0; j < dimXY; j++ ) {
116 meanXY[ j ] = meanXY[ j ] / (double)N;
// Pass 2: accumulate squared deviations from the mean, per column. The
// initialisation of stdXY is not visible here (assumed zeroed -- TODO confirm).
120 for(
unsigned int i = 0; i < N; i++ ) {
121 for(
unsigned int j = 0; j < dimXY; j++ ) {
122 stdXY[ j ] += pow( dataXY[ i ][ j ] - meanXY[ j ], 2.0 );
// Standard deviation with an N-1 denominator.
// NOTE(review): N == 1 makes the denominator zero.
125 for(
unsigned int j = 0; j < dimXY; j++ ) {
126 stdXY[ j ] = sqrt( stdXY[ j ] / ((
double)N-1.0) );
// Pass 3: rescale each row in place, then split it into its X and Y parts.
138 for(
unsigned int i = 0; i < N; i++ ) {
// NOTE(review): a column with zero spread has stdXY[ j ] == 0 and divides by
// zero here; no guard is visible in this excerpt.
140 for(
unsigned int j = 0; j < dimXY; j++ ) {
141 dataXY[ i ][ j ] = ( dataXY[ i ][ j ] - meanXY[ j ] ) / stdXY[ j ];
// Copy the already-normalised X columns.
145 for(
unsigned int j = 0; j < dimX; j++ ) {
146 dataX[ i ][ j ] = dataXY[ i ][ j ];
// Copy the already-normalised Y columns, which start at offset dimX.
148 for(
unsigned int j = 0; j < dimY; j++ ) {
149 dataY[ i ][ j ] = dataXY[ i ][ dimX + j ];
// Fragment of the routine that computes MI_est from a joint sample. Its name
// and first parameter are not visible in this excerpt (presumably
// computeMI_ANN taking dataXY -- TODO confirm).
159 unsigned int dimX,
unsigned int dimY,
160 unsigned int k,
unsigned int N,
double eps )
// Joint dimension is the sum of the two marginal dimensions.
169 unsigned int dimXY = dimX + dimY;
// One distance per sample point. The matching release is not visible in this
// excerpt -- TODO confirm it exists.
174 distsXY =
new double[N];
// Standardise the joint sample and fill the marginal arrays dataX / dataY.
177 normalizeANN_XY( dataXY, dimXY, dataX, dimX, dataY, dimY, N);
// Query the joint sample against itself: both point sets are dataXY.
182 distANN_XY( dataXY, dataXY, distsXY, dimXY, dimXY, N, N, k, eps );
// Sum of the per-point digamma terms from both marginal spaces.
185 double marginal_contrib = 0.0;
186 for(
unsigned int i = 0; i < N; i++ ) {
// Fixed-radius searches with k = 0 and NULL output arrays: only the returned
// count is used. ANN declares the radius parameter as sqRad, so
// NOTE(review): confirm distsXY[ i ] holds a squared distance.
188 int no_pts_X = kdTreeX->
annkFRSearch( dataX[ i ], distsXY[ i ], 0, NULL, NULL, eps);
189 int no_pts_Y = kdTreeY->
annkFRSearch( dataY[ i ], distsXY[ i ], 0, NULL, NULL, eps);
// gsl_sf_psi_int is GSL's digamma function for integer arguments.
191 marginal_contrib += gsl_sf_psi_int( no_pts_X+1 ) + gsl_sf_psi_int( no_pts_Y+1 );
// Estimate: psi(k) + psi(N) minus the average marginal contribution.
193 MI_est = gsl_sf_psi_int( k ) + gsl_sf_psi_int( N ) - marginal_contrib / (double)N;
// Overload taking one joint random variable. Draws N realizations from
// jointRV, packs the components selected by xDimSel / yDimSel into the rows
// of dataXY, and passes dataXY to computeMI_ANN (the remaining arguments of
// that call are not visible in this excerpt).
210 template<
template <
class P_V,
class P_M>
class RV,
class P_V,
class P_M>
211 double estimateMI_ANN(
const RV<P_V,P_M>& jointRV,
212 const unsigned int xDimSel[],
unsigned int dimX,
213 const unsigned int yDimSel[],
unsigned int dimY,
214 unsigned int k,
unsigned int N,
double eps )
219 unsigned int dimXY = dimX + dimY;
// Scratch vector built from the zero vector of the RV's image space; it is
// overwritten by each realization below.
225 P_V smpRV( jointRV.imageSet().vectorSpace().zeroVector() );
226 for(
unsigned int i = 0; i < N; i++ ) {
// Draw one joint sample into smpRV.
228 jointRV.realizer().realization( smpRV );
// Columns [0, dimX): the components listed in xDimSel.
231 for(
unsigned int j = 0; j < dimX; j++ ) {
232 dataXY[ i ][ j ] = smpRV[ xDimSel[j] ];
// Columns [dimX, dimX + dimY): the components listed in yDimSel.
234 for(
unsigned int j = 0; j < dimY; j++ ) {
235 dataXY[ i ][ dimX + j ] = smpRV[ yDimSel[j] ];
// Hand the assembled sample to the estimator.
240 MI_est = computeMI_ANN( dataXY,
// Overload taking two separate random variables. Each iteration draws one
// realization from xRV and one from yRV, packs the selected components side
// by side into a row of dataXY, and passes dataXY to computeMI_ANN (the
// remaining arguments of that call are not visible in this excerpt).
254 template<
class P_V,
class P_M,
255 template <
class P_V,
class P_M>
class RV_1,
256 template <
class P_V,
class P_M>
class RV_2>
257 double estimateMI_ANN(
const RV_1<P_V,P_M>& xRV,
258 const RV_2<P_V,P_M>& yRV,
259 const unsigned int xDimSel[],
unsigned int dimX,
260 const unsigned int yDimSel[],
unsigned int dimY,
261 unsigned int k,
unsigned int N,
double eps )
266 unsigned int dimXY = dimX + dimY;
// Scratch vectors, one per RV, each built from the zero vector of that RV's
// image space.
272 P_V smpRV_x( xRV.imageSet().vectorSpace().zeroVector() );
273 P_V smpRV_y( yRV.imageSet().vectorSpace().zeroVector() );
275 for(
unsigned int i = 0; i < N; i++ ) {
// Row i pairs the i-th draw from xRV with the i-th draw from yRV.
277 xRV.realizer().realization( smpRV_x );
278 yRV.realizer().realization( smpRV_y );
// Columns [0, dimX): components of the x draw listed in xDimSel.
281 for(
unsigned int j = 0; j < dimX; j++ ) {
282 dataXY[ i ][ j ] = smpRV_x[ xDimSel[j] ];
// Columns [dimX, dimX + dimY): components of the y draw listed in yDimSel.
284 for(
unsigned int j = 0; j < dimY; j++ ) {
285 dataXY[ i ][ dimX + j ] = smpRV_y[ yDimSel[j] ];
// Hand the assembled sample to the estimator.
290 MI_est = computeMI_ANN( dataXY,
// Computes KL_est from xN draws of xRV and yN draws of yRV using
// nearest-neighbour distances. The yRV parameter is used below but its
// declaration is not visible in this excerpt.
304 template <
class P_V,
class P_M,
305 template <
class P_V,
class P_M>
class RV_1,
306 template <
class P_V,
class P_M>
class RV_2>
307 double estimateKL_ANN( RV_1<P_V,P_M>& xRV,
309 unsigned int xDimSel[],
unsigned int dimX,
310 unsigned int yDimSel[],
unsigned int dimY,
311 unsigned int xN,
unsigned int yN,
312 unsigned int k,
double eps )
// Both distance arrays hold one entry per x sample. Their release is not
// visible in this excerpt -- TODO confirm it exists.
328 distsX =
new double[xN];
329 distsXY =
new double[xN];
// Fill dataX with xN realizations of xRV, keeping the components in xDimSel.
332 P_V xSmpRV( xRV.imageSet().vectorSpace().zeroVector() );
333 for(
unsigned int i = 0; i < xN; i++ ) {
335 xRV.realizer().realization( xSmpRV );
337 for(
unsigned int j = 0; j < dimX; j++ ) {
338 dataX[ i ][ j ] = xSmpRV[ xDimSel[j] ];
// Fill dataY with yN realizations of yRV, keeping the components in yDimSel.
343 P_V ySmpRV( yRV.imageSet().vectorSpace().zeroVector() );
344 for(
unsigned int i = 0; i < yN; i++ ) {
346 yRV.realizer().realization( ySmpRV );
348 for(
unsigned int j = 0; j < dimY; j++ ) {
349 dataY[ i ][ j ] = ySmpRV[ yDimSel[j] ];
// x against x uses k+1, x against y uses k. Presumably the extra neighbour
// skips each x point's match with itself -- TODO confirm.
354 distANN_XY( dataX, dataX, distsX, dimX, dimX, xN, xN, k+1, eps );
355 distANN_XY( dataX, dataY, distsXY, dimX, dimY, xN, yN, k, eps );
// Sum over x samples of log( distance to y sample / distance within x ).
// NOTE(review): distsX[i] == 0 divides by zero here; no guard is visible in
// this excerpt.
358 double sum_log_ratio = 0.0;
359 for(
unsigned int i = 0; i < xN; i++ )
361 sum_log_ratio += log( distsXY[i] / distsX[i] );
// Estimate: (dimX / xN) * sum_log_ratio + log( yN / (xN - 1) ).
// NOTE(review): xN == 1 makes the last denominator zero.
363 KL_est = (double)dimX/(
double)xN * sum_log_ratio + log( (
double)yN / ((
double)xN-1.0 ) );
// Computes CE_est from xN draws of xRV and yN draws of yRV using the
// nearest-neighbour distances from each x sample to the y sample. The yRV
// parameter is used below but its declaration is not visible in this excerpt.
379 template <
class P_V,
class P_M,
380 template <
class P_V,
class P_M>
class RV_1,
381 template <
class P_V,
class P_M>
class RV_2>
382 double estimateCE_ANN( RV_1<P_V,P_M>& xRV,
384 unsigned int xDimSel[],
unsigned int dimX,
385 unsigned int yDimSel[],
unsigned int dimY,
386 unsigned int xN,
unsigned int yN,
387 unsigned int k,
double eps )
// One distance per x sample. The release is not visible in this excerpt --
// TODO confirm it exists.
403 distsXY =
new double[xN];
// Fill dataX with xN realizations of xRV, keeping the components in xDimSel.
407 P_V xSmpRV( xRV.imageSet().vectorSpace().zeroVector() );
408 for(
unsigned int i = 0; i < xN; i++ ) {
410 xRV.realizer().realization( xSmpRV );
412 for(
unsigned int j = 0; j < dimX; j++ ) {
413 dataX[ i ][ j ] = xSmpRV[ xDimSel[j] ];
// Fill dataY with yN realizations of yRV, keeping the components in yDimSel.
418 P_V ySmpRV( yRV.imageSet().vectorSpace().zeroVector() );
419 for(
unsigned int i = 0; i < yN; i++ ) {
421 yRV.realizer().realization( ySmpRV );
423 for(
unsigned int j = 0; j < dimY; j++ ) {
424 dataY[ i ][ j ] = ySmpRV[ yDimSel[j] ];
// Distances from each x sample to the y sample.
429 distANN_XY( dataX, dataY, distsXY, dimX, dimY, xN, yN, k, eps );
// Sum of log-distances over the x samples.
// NOTE(review): a zero distsXY[i] yields log(0); no guard is visible in this
// excerpt.
433 double sum_log = 0.0;
434 for(
unsigned int i = 0; i < xN; i++ )
436 sum_log += log( distsXY[i] );
// Estimate: (dimX / xN) * sum_log + log( yN ) - psi( k ), where
// gsl_sf_psi_int is GSL's digamma function for integer arguments.
438 CE_est = (double)dimX/(
double)xN * sum_log + log( (
double)yN ) - gsl_sf_psi_int( k );
450 #endif // QUESO_HAS_ANN
DLL_API void annDeallocPts(ANNpointArray &pa)
#define queso_error_msg(msg)
int annkFRSearch(ANNpoint q, ANNdist sqRad, int k, ANNidxArray nn_idx=NULL, ANNdistArray dd=NULL, double eps=0.0)
void annkSearch(ANNpoint q, int k, ANNidxArray nn_idx, ANNdistArray dd, double eps=0.0)
DLL_API ANNpointArray annAllocPts(int n, int dim)
DLL_API ANNpoint annAllocPt(int dim, ANNcoord c=0)