First version of kdtree (Alexander, Marian)
[u/mrichter/AliRoot.git] / STAT / TKDInterpolator.cxx
CommitLineData
f2040a8f 1#include "TKDInterpolator.h"
2
3#include "TLinearFitter.h"
4#include "TVector.h"
5#include "TTree.h"
6#include "TH2.h"
7#include "TObjArray.h"
8#include "TObjString.h"
9#include "TBox.h"
10#include "TGraph.h"
11#include "TMarker.h"
12
13
14
// ROOT class-implementation macro for TKDInterpolator.
ClassImp(TKDInterpolator)
16
/////////////////////////////////////////////////////////////////////
// Memory setup of protected data members
// fRefPoints : evaluation point of the PDF for each terminal node of the underlying KD tree,
//              stored per dimension: fRefPoints[idim][inode] is coordinate idim
//              (idim < fNDim) of the evaluation point of terminal node inode (inode < fNTNodes).
//
// fRefValues : evaluation value of the PDF for each terminal node of the underlying KD tree
//              (errors are not stored).
//              | 1st terminal node (value) | 2nd terminal node (value) | ...
/////////////////////////////////////////////////////////////////////
25
26//_________________________________________________________________
27TKDInterpolator::TKDInterpolator() : TKDTreeIF()
28 ,fNTNodes(0)
29 ,fRefPoints(0x0)
30 ,fRefValues(0x0)
31 ,fDepth(-1)
32 ,fTmpPoint(0x0)
33 ,fKDhelper(0x0)
34 ,fFitter(0x0)
35{
36}
37
38//_________________________________________________________________
39TKDInterpolator::TKDInterpolator(Int_t npoints, Int_t ndim, UInt_t bsize, Float_t **data) : TKDTreeIF(npoints, ndim, bsize, data)
40 ,fNTNodes(GetNTerminalNodes())
41 ,fRefPoints(0x0)
42 ,fRefValues(0x0)
43 ,fDepth(-1)
44 ,fTmpPoint(0x0)
45 ,fKDhelper(0x0)
46 ,fFitter(0x0)
47{
48 Build();
49}
50
51
52//_________________________________________________________________
53TKDInterpolator::TKDInterpolator(TTree *t, const Char_t *var, const Char_t *cut, UInt_t bsize) : TKDTreeIF()
54 ,fNTNodes(0)
55 ,fRefPoints(0x0)
56 ,fRefValues(0x0)
57 ,fDepth(-1)
58 ,fTmpPoint(0x0)
59 ,fKDhelper(0x0)
60 ,fFitter(0x0)
61{
62// Alocate data from a tree. The variables which have to be analysed are
63// defined in the "var" parameter as a colon separated list. The format should
64// be identical to that used by TTree::Draw().
65//
66//
67
68 fNpoints = t->GetEntriesFast();
69 TObjArray *vars = TString(var).Tokenize(":");
70 fNDim = vars->GetEntriesFast();
71 if(fNDim > 6/*kDimMax*/) Warning("TKDInterpolator(TTree*, UInt_t, const Char_t)", Form("Variable number exceed maximum dimension %d. Results are unpredictable.", 6/*kDimMax*/));
72 fBucketSize = bsize;
73
74 printf("Allocating %d points in %d dimensions.\n", fNpoints, fNDim);
75 Float_t *mem = new Float_t[fNDim*fNpoints];
76 fData = new Float_t*[fNDim];
77 for(int idim=0; idim<fNDim; idim++) fData[idim] = &mem[idim*fNpoints];
78 kDataOwner = kTRUE;
79
80 Double_t *v;
81 for(int idim=0; idim<fNDim; idim++){
82 if(!(t->Draw(((TObjString*)(*vars)[idim])->GetName(), cut, "goff"))){
83 Warning("TKDInterpolator(TTree*, UInt_t, const Char_t)", Form("Can not access data for %s", ((TObjString*)(*vars)[idim])->GetName() ));
84 continue;
85 }
86 v = t->GetV1();
87 for(int ip=0; ip<fNpoints; ip++) fData[idim][ip] = (Float_t)v[ip];
88 }
89 TKDTreeIF::Build();
90 fNTNodes = GetNTerminalNodes();
91 Build();
92}
93
94//_________________________________________________________________
95TKDInterpolator::~TKDInterpolator()
96{
97 if(fFitter) delete fFitter;
98 if(fKDhelper) delete fKDhelper;
99 if(fTmpPoint) delete [] fTmpPoint;
100
101 if(fRefPoints){
102 for(int idim=0; idim<fNDim; idim++) delete [] fRefPoints[idim] ;
103 delete [] fRefPoints;
104 }
105 if(fRefValues) delete [] fRefValues;
106}
107
108//_________________________________________________________________
109void TKDInterpolator::Build()
110{
111 if(!fBoundaries) MakeBoundaries();
112
113 // allocate memory for data
114 fRefValues = new Float_t[fNTNodes];
115 fRefPoints = new Float_t*[fNDim];
116 for(int id=0; id<fNDim; id++){
117 fRefPoints[id] = new Float_t[fNTNodes];
118 for(int in=0; in<fNTNodes; in++) fRefPoints[id][in] = 0.;
119 }
120
121 Float_t *bounds = 0x0;
122 Int_t *indexPoints;
123 for(int inode=0, tnode = fNnodes; inode<fNTNodes-1; inode++, tnode++){
124 fRefValues[inode] = Float_t(fBucketSize)/fNpoints;
125 bounds = GetBoundary(tnode);
126 for(int idim=0; idim<fNDim; idim++) fRefValues[inode] /= (bounds[2*idim+1] - bounds[2*idim]);
127
128 indexPoints = GetPointsIndexes(tnode);
129 // loop points in this terminal node
130 for(int idim=0; idim<fNDim; idim++){
131 for(int ip = 0; ip<fBucketSize; ip++) fRefPoints[idim][inode] += fData[idim][indexPoints[ip]];
132 fRefPoints[idim][inode] /= fBucketSize;
133 }
134 }
135
136 // analyze last (incomplete) terminal node
137 Int_t counts = fNpoints%fBucketSize;
138 counts = counts ? counts : fBucketSize;
139 Int_t inode = fNTNodes - 1, tnode = inode + fNnodes;
140 fRefValues[inode] = Float_t(counts)/fNpoints;
141 bounds = GetBoundary(tnode);
142 for(int idim=0; idim<fNDim; idim++) fRefValues[inode] /= (bounds[2*idim+1] - bounds[2*idim]);
143
144 indexPoints = GetPointsIndexes(tnode);
145 // loop points in this terminal node
146 for(int idim=0; idim<fNDim; idim++){
147 for(int ip = 0; ip<counts; ip++) fRefPoints[idim][inode] += fData[idim][indexPoints[ip]];
148 fRefPoints[idim][inode] /= counts;
149 }
150}
151
152//_________________________________________________________________
153Double_t TKDInterpolator::Eval(Float_t *point)
154{
155
156 // calculate number of parameters in the parabolic expresion
157 Int_t kNN = 1 + fNDim + fNDim*(fNDim+1)/2;
158
159 // prepare workers
160 if(!fTmpPoint) fTmpPoint = new Double_t[fNDim];
161 if(!fKDhelper) fKDhelper = new TKDTreeIF(GetNTerminalNodes(), fNDim, kNN, fRefPoints);
162 if(!fFitter){
163 // generate formula for nD
164 TString formula("1");
165 for(int idim=0; idim<fNDim; idim++){
166 formula += Form("++x[%d]", idim);
167 for(int jdim=idim; jdim<fNDim; jdim++) formula += Form("++x[%d]*x[%d]", idim, jdim);
168 }
169 fFitter = new TLinearFitter(kNN, formula.Data());
170 }
171
172 Int_t kNN_old = 0;
173 Int_t *index;
174 Float_t dist;
175 fFitter->ClearPoints();
176 do{
177 if(!fKDhelper->FindNearestNeighbors(point, kNN, index, dist)){
178 Error("Eval()", Form("Failed retriving %d neighbours for point:", kNN));
179 for(int idim=0; idim<fNDim; idim++) printf("%f ", point[idim]);
180 printf("\n");
181 return -1;
182 }
183 for(int in=kNN_old; in<kNN; in++){
184 for(int idim=0; idim<fNDim; idim++) fTmpPoint[idim] = fRefPoints[idim][index[in]];
185 fFitter->AddPoint(fTmpPoint, TMath::Log(fRefValues[index[in]]), 1.);
186 }
187 kNN_old = kNN;
188 kNN += 4;
189 } while(fFitter->Eval());
190
191 // calculate evaluation
192 TVectorD par(kNN);
193 fFitter->GetParameters(par);
194 Double_t result = par[0];
195 Int_t ipar = 0;
196 for(int idim=0; idim<fNDim; idim++){
197 result += par[++ipar]*point[idim];
198 for(int jdim=idim; jdim<fNDim; jdim++) result += par[++ipar]*point[idim]*point[jdim];
199 }
200 return TMath::Exp(result);
201}
202
203
204//_________________________________________________________________
205void TKDInterpolator::DrawNodes(Int_t depth, Int_t ax1, Int_t ax2)
206{
207// Draw nodes structure projected on plane "ax1:ax2". The parameter
208// "depth" specifies the bucket size per node. If depth == -1 draw only
209// terminal nodes and evaluation points (default -1 i.e. bucket size per node equal bucket size specified by the user)
210
211 if(!fBoundaries) MakeBoundaries();
212
213 // Count nodes in specific view
214 Int_t nnodes = 0;
215 for(int inode = 0; inode <= 2*fNnodes; inode++){
216 if(depth == -1){
217 if(!IsTerminal(inode)) continue;
218 } else if((inode+1) >> depth != 1) continue;
219 nnodes++;
220 }
221
222 //printf("depth %d nodes %d\n", depth, nnodes);
223
224 //TH2 *h2 = new TH2S("hframe", "", 100, fRange[2*ax1], fRange[2*ax1+1], 100, fRange[2*ax2], fRange[2*ax2+1]);
225 TH2 *h2 = new TH2S("hframe", "", 100, 0., 1., 100, 0., 1.);
226 h2->Draw();
227
228 const Float_t border = 0.;//1.E-4;
229 TBox **node_array = new TBox*[nnodes], *node;
230 Float_t *bounds = 0x0;
231 nnodes = 0;
232 for(int inode = 0; inode <= 2*fNnodes; inode++){
233 if(depth == -1){
234 if(!IsTerminal(inode)) continue;
235 } else if((inode+1) >> depth != 1) continue;
236
237 node = node_array[nnodes++];
238 bounds = GetBoundary(inode);
239 node = new TBox(bounds[2*ax1]+border, bounds[2*ax2]+border, bounds[2*ax1+1]-border, bounds[2*ax2+1]-border);
240 node->SetFillStyle(0);
241 node->SetFillColor(51+inode);
242 node->Draw();
243 }
244 if(depth != -1) return;
245
246 // Draw reference points
247 TGraph *ref = new TGraph(GetNTerminalNodes());
248 ref->SetMarkerStyle(2);
249 ref->SetMarkerColor(2);
250 Float_t val, error;
251 for(int inode = 0; inode < GetNTerminalNodes(); inode++) ref->SetPoint(inode, fRefPoints[ax1][inode], fRefPoints[ax2][inode]);
252 ref->Draw("p");
253 return;
254}
255
256//_________________________________________________________________
257void TKDInterpolator::DrawNode(Int_t tnode, Int_t ax1, Int_t ax2)
258{
259// Draw node "node" and the data points within.
260
261 if(tnode < 0 || tnode >= GetNTerminalNodes()){
262 Warning("DrawNode()", Form("Terminal node %d outside defined range.", tnode));
263 return;
264 }
265
266 //TH2 *h2 = new TH2S("hframe", "", 1, fRange[2*ax1], fRange[2*ax1+1], 1, fRange[2*ax2], fRange[2*ax2+1]);
267 TH2 *h2 = new TH2S("hframe", "", 1, 0., 1., 1, 0., 1.);
268 h2->Draw();
269
270 Int_t inode = tnode;
271 tnode += fNnodes;
272 // select zone of interest in the indexes array
273 Int_t *index = GetPointsIndexes(tnode);
274 Int_t nPoints = (tnode == 2*fNnodes) ? fNpoints%fBucketSize : fBucketSize;
275
276 printf("true index %d points %d\n", tnode, nPoints);
277
278 // draw data points
279 TGraph *g = new TGraph(nPoints);
280 g->SetMarkerStyle(3);
281 g->SetMarkerSize(.8);
282 for(int ip = 0; ip<nPoints; ip++) g->SetPoint(ip, fData[ax1][index[ip]], fData[ax2][index[ip]]);
283 g->Draw("p");
284
285 // draw estimation point
286 TMarker *m=new TMarker(fRefPoints[ax1][inode], fRefPoints[ax2][inode], 2);
287 m->SetMarkerColor(2);
288 m->Draw();
289
290 // draw node contour
291 Float_t *bounds = GetBoundary(tnode);
292 TBox *n = new TBox(bounds[2*ax1], bounds[2*ax2], bounds[2*ax1+1], bounds[2*ax2+1]);
293 n->SetFillStyle(0);
294 n->Draw();
295
296 return;
297}
298