]> git.uio.no Git - u/mrichter/AliRoot.git/blob - STAT/TKDInterpolatorBase.cxx
TKDTree class now in ROOT
[u/mrichter/AliRoot.git] / STAT / TKDInterpolatorBase.cxx
1 #include "TKDInterpolatorBase.h"
2 #include "TKDNodeInfo.h"
3 #include "TKDTree.h"
4
5 #include "TClonesArray.h"
6 #include "TLinearFitter.h"
7 #include "TTree.h"
8 #include "TH2.h"
9 #include "TObjArray.h"
10 #include "TObjString.h"
11 #include "TBox.h"
12 #include "TGraph.h"
13 #include "TMarker.h"
14 #include "TMath.h"
15 #include "TVectorD.h"
16 #include "TMatrixD.h"
17
18 ClassImp(TKDInterpolatorBase)
19
20 /////////////////////////////////////////////////////////////////////
21 // Memory setup of protected data members
22 // fRefPoints : evaluation point of PDF for each terminal node of underlying KD Tree.
23 // | 1st terminal node (fNDim point coordinates) | 2nd terminal node (fNDim point coordinates) | ...
24 //
25 // fRefValues : evaluation value/error of PDF for each terminal node of underlying KD Tree.
26 // | 1st terminal node (value) | 2nd terminal node (value) | ... | 1st terminal node (error) | 2nd terminal node (error) | ...
27 //
28 // status = |0|0|0|0|0|1(tri-cubic weights)|1(STORE)|1 INT(0 COG )|
29 /////////////////////////////////////////////////////////////////////
30
31
32 //_________________________________________________________________
33 TKDInterpolatorBase::TKDInterpolatorBase(Int_t dim) :
34   fNSize(dim)
35   ,fNTNodes(0)
36   ,fTNodes(0x0)
37   ,fStatus(4)
38   ,fLambda(1 + dim + (dim*(dim+1)>>1))
39   ,fDepth(-1)
40   ,fAlpha(.5)
41   ,fRefPoints(0x0)
42   ,fBuffer(0x0)
43   ,fKDhelper(0x0)
44   ,fFitter(0x0)
45 {
46 // Default constructor. To be used with care since in this case building
47 // of data structure is completly left to the user responsability.
48 }
49
50 //_________________________________________________________________
51 void    TKDInterpolatorBase::Build(Int_t n)
52 {
53   // allocate memory for data
54
55   if(fTNodes) delete fTNodes;
56   fNTNodes = n;
57   fTNodes = new TClonesArray("TKDNodeInfo", fNTNodes);
58   for(int in=0; in<fNTNodes; in++) new ((*fTNodes)[in]) TKDNodeInfo(fNSize);
59 }
60
61 //_________________________________________________________________
62 TKDInterpolatorBase::~TKDInterpolatorBase()
63 {
64   if(fFitter) delete fFitter;
65   if(fKDhelper) delete fKDhelper;
66   if(fBuffer) delete [] fBuffer;
67   
68   if(fRefPoints){
69     for(int idim=0; idim<fNSize; idim++) delete [] fRefPoints[idim] ;
70     delete [] fRefPoints;
71   }
72   if(fTNodes) delete fTNodes;
73 }
74
75
76 //__________________________________________________________________
77 Bool_t  TKDInterpolatorBase::GetCOGPoint(Int_t inode, Float_t *&coord, Float_t &val, Float_t &err) const
78 {
79   if(inode < 0 || inode > fNTNodes) return kFALSE;
80
81   TKDNodeInfo *node = (TKDNodeInfo*)(*fTNodes)[inode];
82   coord = &(node->Data()[0]);
83   val = node->Val()[0];
84   err = node->Val()[1];
85   return kTRUE;
86 }
87
88 //_________________________________________________________________
89 TKDNodeInfo* TKDInterpolatorBase::GetNodeInfo(Int_t inode) const
90 {
91   if(!fTNodes || inode >= fNTNodes) return 0x0;
92   return (TKDNodeInfo*)(*fTNodes)[inode];
93 }
94
95
96 //__________________________________________________________________
97 void TKDInterpolatorBase::GetStatus()
98 {
99 // Prints the status of the interpolator
100
101   printf("Interpolator Status :\n");
102   printf("  Dim    : %d [%d]\n", fNSize, fLambda);
103   printf("  Method : %s\n", fStatus&1 ? "INT" : "COG");
104   printf("  Store  : %s\n", fStatus&2 ? "YES" : "NO");
105   printf("  Weights: %s\n", fStatus&4 ? "YES" : "NO");
106   
107   printf("fNTNodes %d\n", fNTNodes);        //Number of evaluation data points
108   for(int i=0; i<fNTNodes; i++){
109     TKDNodeInfo *node = (TKDNodeInfo*)(*fTNodes)[i]; 
110     printf("%d ", i); node->Print();
111   }
112 }
113
114 //_________________________________________________________________
115 Double_t TKDInterpolatorBase::Eval(const Double_t *point, Double_t &result, Double_t &error, Bool_t force)
116 {
117 // Evaluate PDF for "point". The result is returned in "result" and error in "error". The function returns the chi2 of the fit.
118 //
119 // Observations:
120 //
121 // 1. The default method used for interpolation is kCOG.
122 // 2. The initial number of neighbors used for the estimation is set to Int(alpha*fLambda) (alpha = 1.5)
123       
124   Float_t pointF[50]; // local Float_t conversion for "point"
125   for(int idim=0; idim<fNSize; idim++) pointF[idim] = (Float_t)point[idim];
126   Int_t nodeIndex = GetNodeIndex(pointF);
127   if(nodeIndex<0){
128     result = 0.;
129     error = 1.E10;
130     return 0.;
131   }
132   TKDNodeInfo *node = (TKDNodeInfo*)(*fTNodes)[nodeIndex];
133   if((fStatus&1) && node->Cov() && !force) return node->CookPDF(point, result, error);
134
135   // Allocate memory
136   if(!fBuffer) fBuffer = new Double_t[2*fLambda];
137   if(!fKDhelper){ 
138     fRefPoints = new Float_t*[fNSize];
139     for(int id=0; id<fNSize; id++){
140       fRefPoints[id] = new Float_t[fNTNodes];
141       for(int in=0; in<fNTNodes; in++) fRefPoints[id][in] = ((TKDNodeInfo*)(*fTNodes)[in])->Data()[id];
142     }
143     fKDhelper = new TKDTreeIF(fNTNodes, fNSize, 30, fRefPoints);
144     fKDhelper->MakeBoundaries();
145   }
146   if(!fFitter) fFitter = new TLinearFitter(fLambda, Form("hyp%d", fLambda-1));
147   
148   // generate parabolic for nD
149   //Float_t alpha = Float_t(2*lambda + 1) / fNTNodes; // the bandwidth or smoothing parameter
150   //Int_t npoints = Int_t(alpha * fNTNodes);
151   //printf("Params : %d NPoints %d\n", lambda, npoints);
152   // prepare workers
153
154   Int_t ipar,    // local looping variable
155         npoints = Int_t((1.+fAlpha)*fLambda); // number of data points used for interpolation
156   Int_t *index = new Int_t[2*npoints];  // indexes of NN 
157   Float_t *dist = new Float_t[2*npoints], // distances of NN
158           d,     // NN normalized distance
159           w0,    // work
160           w;     // tri-cubic weight function
161
162   do{
163     // find nearest neighbors
164     for(int idim=0; idim<fNSize; idim++) pointF[idim] = (Float_t)point[idim];
165     fKDhelper->FindNearestNeighbors(pointF, npoints+1, index, dist);
166     // add points to fitter
167     fFitter->ClearPoints();
168     TKDNodeInfo *tnode = 0x0;
169     for(int in=0; in<npoints; in++){
170       tnode = (TKDNodeInfo*)(*fTNodes)[index[in]];
171       //tnode->Print();
172       if(fStatus&1){ // INT
173         Float_t *bounds = &(tnode->Data()[fNSize]);
174         ipar = 0;
175         for(int idim=0; idim<fNSize; idim++){
176           fBuffer[ipar++] = .5*(bounds[2*idim] + bounds[2*idim+1]);
177           fBuffer[ipar++] = (bounds[2*idim]*bounds[2*idim] + bounds[2*idim] * bounds[2*idim+1] + bounds[2*idim+1] * bounds[2*idim+1])/3.;
178           for(int jdim=idim+1; jdim<fNSize; jdim++) fBuffer[ipar++] = (bounds[2*idim] + bounds[2*idim+1]) * (bounds[2*jdim] + bounds[2*jdim+1]) * .25;
179         }
180       } else { // COG
181         Float_t *p = &(tnode->Data()[0]);
182         ipar = 0;
183         for(int idim=0; idim<fNSize; idim++){
184           fBuffer[ipar++] = p[idim];
185           for(int jdim=idim; jdim<fNSize; jdim++) fBuffer[ipar++] = p[idim]*p[jdim];
186         }
187       }
188
189       // calculate tri-cubic weighting function
190       if(fStatus&4){
191         d = dist[in]/ dist[npoints];
192         w0 = (1. - d*d*d); w = w0*w0*w0;
193       } else w = 1.;
194       
195 //                      printf("x[");
196 //                      for(int idim=0; idim<fLambda-1; idim++) printf("%f ", fBuffer[idim]);
197 //                      printf("]  v[%f +- %f] (%f, %f)\n", tnode->Val()[0], tnode->Val()[1]/w, tnode->Val()[1], w);
198       fFitter->AddPoint(fBuffer, tnode->Val()[0], tnode->Val()[1]/w);
199     }
200     npoints += 4;
201   } while(fFitter->Eval());
202   delete [] index;
203   delete [] dist;
204
205   // retrive fitter results
206   TMatrixD cov(fLambda, fLambda);
207   TVectorD par(fLambda);
208   fFitter->GetCovarianceMatrix(cov);
209   fFitter->GetParameters(par);
210   Double_t chi2 = fFitter->GetChisquare()/(npoints - 4 - fLambda);
211
212   // store results
213   if(fStatus&2 && fStatus&1) node->Store(par, cov);
214     
215   // Build df/dpi|x values
216   Double_t *fdfdp = &fBuffer[fLambda];
217   ipar = 0;
218   fdfdp[ipar++] = 1.;
219   for(int idim=0; idim<fNSize; idim++){
220     fdfdp[ipar++] = point[idim];
221     for(int jdim=idim; jdim<fNSize; jdim++) fdfdp[ipar++] = point[idim]*point[jdim];
222   }
223
224   // calculate estimation
225   result =0.; error = 0.;
226   for(int i=0; i<fLambda; i++){
227     result += fdfdp[i]*par(i);
228     for(int j=0; j<fLambda; j++) error += fdfdp[i]*fdfdp[j]*cov(i,j);
229   }     
230   error = TMath::Sqrt(error);
231
232   return chi2;
233 }
234
235 //_________________________________________________________________
236 void TKDInterpolatorBase::DrawBins(UInt_t ax1, UInt_t ax2, Float_t ax1min, Float_t ax1max, Float_t ax2min, Float_t ax2max)
237 {
238 // Draw nodes structure projected on plane "ax1:ax2". The parameter
239 // "depth" specifies the bucket size per node. If depth == -1 draw only
240 // terminal nodes and evaluation points (default -1 i.e. bucket size per node equal bucket size specified by the user)
241 //
242 // Observation:
243 // This function creates the nodes (TBox) array for the specified depth
244 // but don't delete it. Abusing this function may cause memory leaks !
245
246
247   
248   TH2 *h2 = new TH2S("hNodes", "", 100, ax1min, ax1max, 100, ax2min, ax2max);
249   h2->GetXaxis()->SetTitle(Form("x_{%d}", ax1));
250   h2->GetYaxis()->SetTitle(Form("x_{%d}", ax2));
251   h2->Draw();
252   
253   const Float_t kBorder = 0.;//1.E-4;
254   TBox *boxArray = new TBox[fNTNodes], *box;
255   Float_t *bounds = 0x0;
256   for(int inode = 0; inode < fNTNodes; inode++){
257     box = &boxArray[inode];
258     box->SetFillStyle(3002);
259     box->SetFillColor(50+inode/*Int_t(gRandom->Uniform()*50.)*/);
260     
261     bounds = &(((TKDNodeInfo*)(*fTNodes)[inode])->Data()[fNSize]);
262     box->DrawBox(bounds[2*ax1]+kBorder, bounds[2*ax2]+kBorder, bounds[2*ax1+1]-kBorder, bounds[2*ax2+1]-kBorder);
263   }
264
265   // Draw reference points
266   TGraph *ref = new TGraph(fNTNodes);
267   ref->SetMarkerStyle(3);
268   ref->SetMarkerSize(.7);
269   ref->SetMarkerColor(2);
270   for(int inode = 0; inode < fNTNodes; inode++){
271     TKDNodeInfo *node = (TKDNodeInfo*)(*fTNodes)[inode];
272     ref->SetPoint(inode, node->Data()[ax1], node->Data()[ax2]);
273   }
274   ref->Draw("p");
275   return;
276 }
277
278 //__________________________________________________________________
279 void TKDInterpolatorBase::SetInterpolationMethod(Bool_t on)
280 {
281 // Set interpolation bit to "on".
282   
283   if(on) fStatus += fStatus&1 ? 0 : 1;
284   else fStatus += fStatus&1 ? -1 : 0;
285 }
286
287
288 //_________________________________________________________________
289 void TKDInterpolatorBase::SetStore(Bool_t on)
290 {
291 // Set store bit to "on"
292   
293   if(on) fStatus += fStatus&2 ? 0 : 2;
294   else fStatus += fStatus&2 ? -2 : 0;
295 }
296
297 //_________________________________________________________________
298 void TKDInterpolatorBase::SetWeights(Bool_t on)
299 {
300 // Set weights bit to "on"
301   
302   if(on) fStatus += fStatus&4 ? 0 : 4;
303   else fStatus += fStatus&4 ? -4 : 0;
304 }