First version of kdtree (Alexander, Marian)
[u/mrichter/AliRoot.git] / STAT / TKDInterpolator.cxx
1 #include "TKDInterpolator.h"
2
3 #include "TLinearFitter.h"
4 #include "TVector.h"
5 #include "TTree.h"
6 #include "TH2.h"
7 #include "TObjArray.h"
8 #include "TObjString.h"
9 #include "TBox.h"
10 #include "TGraph.h"
11 #include "TMarker.h"
12
13
14
// ROOT class-implementation macro for TKDInterpolator.
ClassImp(TKDInterpolator)
16
/////////////////////////////////////////////////////////////////////
// Memory setup of protected data members
// fRefPoints : evaluation point of PDF for each terminal node of underlying KD Tree,
//              stored per dimension: fRefPoints[idim][inode] is coordinate idim of the
//              evaluation point of terminal node inode (fNDim arrays of fNTNodes entries each).
//
// fRefValues : evaluation value of PDF for each terminal node of underlying KD Tree.
// | 1st terminal node (value) | 2nd terminal node (value) | ...
//              Only the values are filled by Build(); no error array is allocated.
/////////////////////////////////////////////////////////////////////
25
26 //_________________________________________________________________
27 TKDInterpolator::TKDInterpolator() : TKDTreeIF()
28         ,fNTNodes(0)
29         ,fRefPoints(0x0)
30         ,fRefValues(0x0)
31         ,fDepth(-1)
32         ,fTmpPoint(0x0)
33         ,fKDhelper(0x0)
34         ,fFitter(0x0)
35 {
36 }
37
38 //_________________________________________________________________
39 TKDInterpolator::TKDInterpolator(Int_t npoints, Int_t ndim, UInt_t bsize, Float_t **data) : TKDTreeIF(npoints, ndim, bsize, data)
40         ,fNTNodes(GetNTerminalNodes())
41         ,fRefPoints(0x0)
42         ,fRefValues(0x0)
43         ,fDepth(-1)
44         ,fTmpPoint(0x0)
45         ,fKDhelper(0x0)
46         ,fFitter(0x0)
47 {
48         Build();
49 }
50
51
52 //_________________________________________________________________
53 TKDInterpolator::TKDInterpolator(TTree *t, const Char_t *var, const Char_t *cut, UInt_t bsize) : TKDTreeIF()
54         ,fNTNodes(0)
55         ,fRefPoints(0x0)
56         ,fRefValues(0x0)
57         ,fDepth(-1)
58         ,fTmpPoint(0x0)
59         ,fKDhelper(0x0)
60         ,fFitter(0x0)
61 {
62 // Alocate data from a tree. The variables which have to be analysed are
63 // defined in the "var" parameter as a colon separated list. The format should
64 // be identical to that used by TTree::Draw().
65 //
66 // 
67
68         fNpoints = t->GetEntriesFast();
69         TObjArray *vars = TString(var).Tokenize(":");
70         fNDim = vars->GetEntriesFast();
71         if(fNDim > 6/*kDimMax*/) Warning("TKDInterpolator(TTree*, UInt_t, const Char_t)", Form("Variable number exceed maximum dimension %d. Results are unpredictable.", 6/*kDimMax*/));
72         fBucketSize = bsize;
73
74         printf("Allocating %d points in %d dimensions.\n", fNpoints, fNDim);
75         Float_t *mem = new Float_t[fNDim*fNpoints];
76         fData = new Float_t*[fNDim];
77         for(int idim=0; idim<fNDim; idim++) fData[idim] = &mem[idim*fNpoints];
78         kDataOwner = kTRUE;
79         
80         Double_t *v;
81         for(int idim=0; idim<fNDim; idim++){
82                 if(!(t->Draw(((TObjString*)(*vars)[idim])->GetName(), cut, "goff"))){
83                         Warning("TKDInterpolator(TTree*, UInt_t, const Char_t)", Form("Can not access data for %s", ((TObjString*)(*vars)[idim])->GetName() ));
84                         continue;
85                 }
86                 v = t->GetV1();
87                 for(int ip=0; ip<fNpoints; ip++) fData[idim][ip] = (Float_t)v[ip];
88         }
89         TKDTreeIF::Build();
90         fNTNodes = GetNTerminalNodes();
91         Build();
92 }
93
94 //_________________________________________________________________
95 TKDInterpolator::~TKDInterpolator()
96 {
97         if(fFitter) delete fFitter;
98         if(fKDhelper) delete fKDhelper;
99         if(fTmpPoint) delete [] fTmpPoint;
100         
101         if(fRefPoints){
102                 for(int idim=0; idim<fNDim; idim++) delete [] fRefPoints[idim] ;
103                 delete [] fRefPoints;
104         }
105         if(fRefValues) delete [] fRefValues;
106 }
107
108 //_________________________________________________________________
109 void TKDInterpolator::Build()
110 {
111         if(!fBoundaries) MakeBoundaries();
112         
113         // allocate memory for data
114         fRefValues = new Float_t[fNTNodes];
115         fRefPoints = new Float_t*[fNDim];
116         for(int id=0; id<fNDim; id++){
117                 fRefPoints[id] = new Float_t[fNTNodes];
118                 for(int in=0; in<fNTNodes; in++) fRefPoints[id][in] = 0.;
119         }
120
121         Float_t *bounds = 0x0;
122         Int_t *indexPoints;
123         for(int inode=0, tnode = fNnodes; inode<fNTNodes-1; inode++, tnode++){
124                 fRefValues[inode] =  Float_t(fBucketSize)/fNpoints;
125                 bounds = GetBoundary(tnode);
126                 for(int idim=0; idim<fNDim; idim++) fRefValues[inode] /= (bounds[2*idim+1] - bounds[2*idim]);
127
128                 indexPoints = GetPointsIndexes(tnode);
129                 // loop points in this terminal node
130                 for(int idim=0; idim<fNDim; idim++){
131                         for(int ip = 0; ip<fBucketSize; ip++) fRefPoints[idim][inode] += fData[idim][indexPoints[ip]];
132                         fRefPoints[idim][inode] /= fBucketSize;
133                 }
134         }
135
136         // analyze last (incomplete) terminal node
137         Int_t counts = fNpoints%fBucketSize;
138         counts = counts ? counts : fBucketSize;
139         Int_t inode = fNTNodes - 1, tnode = inode + fNnodes;
140         fRefValues[inode] =  Float_t(counts)/fNpoints;
141         bounds = GetBoundary(tnode);
142         for(int idim=0; idim<fNDim; idim++) fRefValues[inode] /= (bounds[2*idim+1] - bounds[2*idim]);
143
144         indexPoints = GetPointsIndexes(tnode);
145         // loop points in this terminal node
146         for(int idim=0; idim<fNDim; idim++){
147                 for(int ip = 0; ip<counts; ip++) fRefPoints[idim][inode] += fData[idim][indexPoints[ip]];
148                 fRefPoints[idim][inode] /= counts;
149         }
150 }
151
152 //_________________________________________________________________
153 Double_t TKDInterpolator::Eval(Float_t *point)
154 {
155
156         // calculate number of parameters in the parabolic expresion
157         Int_t kNN = 1 + fNDim + fNDim*(fNDim+1)/2;
158
159         // prepare workers
160         if(!fTmpPoint) fTmpPoint = new Double_t[fNDim];
161         if(!fKDhelper) fKDhelper = new TKDTreeIF(GetNTerminalNodes(), fNDim, kNN, fRefPoints);
162         if(!fFitter){
163                 // generate formula for nD
164                 TString formula("1");
165                 for(int idim=0; idim<fNDim; idim++){
166                         formula += Form("++x[%d]", idim);
167                         for(int jdim=idim; jdim<fNDim; jdim++) formula += Form("++x[%d]*x[%d]", idim, jdim);
168                 }
169                 fFitter = new TLinearFitter(kNN, formula.Data());
170         }
171         
172         Int_t kNN_old = 0;
173         Int_t *index;
174         Float_t dist;
175         fFitter->ClearPoints();
176         do{
177                 if(!fKDhelper->FindNearestNeighbors(point, kNN, index, dist)){
178                         Error("Eval()", Form("Failed retriving %d neighbours for point:", kNN));
179                         for(int idim=0; idim<fNDim; idim++) printf("%f ", point[idim]);
180                         printf("\n");
181                         return -1;
182                 }
183                 for(int in=kNN_old; in<kNN; in++){
184                         for(int idim=0; idim<fNDim; idim++) fTmpPoint[idim] = fRefPoints[idim][index[in]];
185                         fFitter->AddPoint(fTmpPoint, TMath::Log(fRefValues[index[in]]), 1.);
186                 }
187                 kNN_old = kNN;
188                 kNN += 4;
189         } while(fFitter->Eval());
190
191         // calculate evaluation
192         TVectorD par(kNN);
193         fFitter->GetParameters(par);
194         Double_t result = par[0];
195         Int_t ipar = 0;
196         for(int idim=0; idim<fNDim; idim++){
197                 result += par[++ipar]*point[idim];
198                 for(int jdim=idim; jdim<fNDim; jdim++) result += par[++ipar]*point[idim]*point[jdim];
199         }
200         return TMath::Exp(result);
201 }
202
203
204 //_________________________________________________________________
205 void TKDInterpolator::DrawNodes(Int_t depth, Int_t ax1, Int_t ax2)
206 {
207 // Draw nodes structure projected on plane "ax1:ax2". The parameter
208 // "depth" specifies the bucket size per node. If depth == -1 draw only
209 // terminal nodes and evaluation points (default -1 i.e. bucket size per node equal bucket size specified by the user)
210
211         if(!fBoundaries) MakeBoundaries();
212
213         // Count nodes in specific view
214         Int_t nnodes = 0;
215         for(int inode = 0; inode <= 2*fNnodes; inode++){
216                 if(depth == -1){
217                         if(!IsTerminal(inode)) continue;
218                 } else if((inode+1) >> depth != 1) continue;
219                 nnodes++;
220         }
221
222         //printf("depth %d nodes %d\n", depth, nnodes);
223         
224         //TH2 *h2 = new TH2S("hframe", "", 100, fRange[2*ax1], fRange[2*ax1+1], 100, fRange[2*ax2], fRange[2*ax2+1]);
225         TH2 *h2 = new TH2S("hframe", "", 100, 0., 1., 100, 0., 1.);
226         h2->Draw();
227         
228         const Float_t border = 0.;//1.E-4;
229         TBox **node_array = new TBox*[nnodes], *node;
230         Float_t *bounds = 0x0;
231         nnodes = 0;
232         for(int inode = 0; inode <= 2*fNnodes; inode++){
233                 if(depth == -1){
234                         if(!IsTerminal(inode)) continue;
235                 } else if((inode+1) >> depth != 1) continue;
236
237                 node = node_array[nnodes++];
238                 bounds = GetBoundary(inode);
239                 node = new TBox(bounds[2*ax1]+border, bounds[2*ax2]+border, bounds[2*ax1+1]-border, bounds[2*ax2+1]-border);
240                 node->SetFillStyle(0);  
241                 node->SetFillColor(51+inode);
242                 node->Draw();
243         }
244         if(depth != -1) return;
245
246         // Draw reference points
247         TGraph *ref = new TGraph(GetNTerminalNodes());
248         ref->SetMarkerStyle(2);
249         ref->SetMarkerColor(2);
250         Float_t val, error;
251         for(int inode = 0; inode < GetNTerminalNodes(); inode++) ref->SetPoint(inode, fRefPoints[ax1][inode], fRefPoints[ax2][inode]);
252         ref->Draw("p");
253         return;
254 }
255
256 //_________________________________________________________________
257 void TKDInterpolator::DrawNode(Int_t tnode, Int_t ax1, Int_t ax2)
258 {
259 // Draw node "node" and the data points within.
260
261         if(tnode < 0 || tnode >= GetNTerminalNodes()){
262                 Warning("DrawNode()", Form("Terminal node %d outside defined range.", tnode));
263                 return;
264         }
265
266         //TH2 *h2 = new TH2S("hframe", "", 1, fRange[2*ax1], fRange[2*ax1+1], 1, fRange[2*ax2], fRange[2*ax2+1]);
267         TH2 *h2 = new TH2S("hframe", "", 1, 0., 1., 1, 0., 1.);
268         h2->Draw();
269
270         Int_t inode = tnode;
271         tnode += fNnodes;
272         // select zone of interest in the indexes array
273         Int_t *index = GetPointsIndexes(tnode);
274         Int_t nPoints = (tnode == 2*fNnodes) ? fNpoints%fBucketSize : fBucketSize;
275
276         printf("true index %d points %d\n", tnode, nPoints);
277
278         // draw data points
279         TGraph *g = new TGraph(nPoints);
280         g->SetMarkerStyle(3);
281         g->SetMarkerSize(.8);
282         for(int ip = 0; ip<nPoints; ip++) g->SetPoint(ip, fData[ax1][index[ip]], fData[ax2][index[ip]]);
283         g->Draw("p");
284
285         // draw estimation point
286         TMarker *m=new TMarker(fRefPoints[ax1][inode], fRefPoints[ax2][inode], 2);
287         m->SetMarkerColor(2);
288         m->Draw();
289         
290         // draw node contour
291         Float_t *bounds = GetBoundary(tnode);
292         TBox *n = new TBox(bounds[2*ax1], bounds[2*ax2], bounds[2*ax1+1], bounds[2*ax2+1]);
293         n->SetFillStyle(0);
294         n->Draw();
295         
296         return;
297 }
298