IO factorization for Interpolator
[u/mrichter/AliRoot.git] / STAT / Macros / testInterpolator.C
1 const Int_t ndim = 2;
2 Double_t testInterpolator(const Int_t nstat = 100000)
3 {
4 // Macro for testing the TKDInterpolator.
5 // 
6 // The function which it is interpolated is an uncorrelated landau
7 // distribution in "ndim" dimensions. The function returns the chi2 of
8 // the interpolation.
9 // 
10 // Parameters
11 //      nstat - number of points to be used for training
12 //      kBuild - on/off generate data
13 //      kTransform - on/off outliers compresion
14 //      kPCA - on/off pricipal component analysis 
15
16         const Bool_t kBuild = 1;
17         const Bool_t kTransform = 1;
18         const Bool_t kPCA = 1;
19         gStyle->SetPalette(1);
20         
21         Double_t pntTrain[ndim], pntTest[ndim], pntRotate[ndim];
22         Double_t pdf;
23         TFile *f = 0x0, *fEval = 0x0;
24         TTree *t = 0x0, *tEval = 0x0;
25
26
27         // build data
28         if(kBuild){
29                 printf("build data ... \n");
30                 f = TFile::Open(Form("%dD_LL.root", ndim), "RECREATE");
31                 t = new TTree("db", "Log-Log database");
32                 for(int idim=0; idim<ndim; idim++) t->Branch(Form("x%d", idim), &pntTrain[idim], Form("x%d/D", idim));
33                 
34                 fEval = TFile::Open(Form("%dD_Eval.root", ndim), "RECREATE");
35                 tEval = new TTree("db", "Evaluation database");
36                 for(int idim=0; idim<ndim; idim++) tEval->Branch(Form("x%d", idim), &pntTest[idim], Form("x%d/D", idim));
37                 
38                 for(int istat=0; istat<nstat; istat++){
39                         for(int idim=0; idim<ndim; idim++) pntTrain[idim] = gRandom->Landau(5.);
40                         if(!(istat%3)){ // one third of the statistics is for testing
41                                 memcpy(pntTest, pntTrain, ndim*sizeof(Double_t));
42                                 tEval->Fill();
43                                 continue;
44                         }
45                         if(kTransform)
46                                 for(int idim=0; idim<ndim; idim++)
47                                         if(pntTrain[idim] > 0.) pntTrain[idim] = TMath::Log(pntTrain[idim]);
48                                         else pntTrain[idim] = 0.;
49                         t->Fill();
50                 }
51                 f->cd();
52                 t->Write();
53                 f->Flush();
54                 
55                 fEval->cd();
56                 tEval->Write();
57                 fEval->Flush();
58         } else {// link data
59                 printf("link data ... \n");
60                 f = TFile::Open(Form("%dD_LL.root", ndim));
61                 t = (TTree*)f->Get("db");
62                 for(int idim=0; idim<ndim; idim++) t->SetBranchAddress(Form("x%d", idim), &pntTrain[idim]);
63                 
64                 fEval = TFile::Open(Form("%dD_Eval.root", ndim));
65                 tEval = (TTree*)fEval->Get("db");
66                 for(int idim=0; idim<ndim; idim++) tEval->SetBranchAddress(Form("x%d", idim), &pntTest[idim]);
67         }
68
69         
70         // do principal component analysis (PCA)
71         TPrincipal princ(ndim, "N");
72         if(kPCA && kBuild){
73                 printf("do principal component analysis (PCA) ... \n");
74                 f->cd();
75                 TTree *tt = new TTree("db1", "PCA database");
76                 for(int idim=0; idim<ndim; idim++) tt->Branch(Form("x%d", idim), &pntRotate[idim], Form("x%d/D", idim));
77                 for(int ientry=0; ientry<t->GetEntries(); ientry++){
78                         t->GetEntry(ientry);
79                         princ.AddRow(pntTrain);
80                 }
81                 princ.MakePrincipals();
82                 for(int ientry=0; ientry<t->GetEntries(); ientry++){
83                         t->GetEntry(ientry);
84                         princ.X2P(pntTrain, pntRotate);
85                         tt->Fill();
86                 }
87                 tt->Write();
88                 f->Flush();
89                 for(int idim=0; idim<ndim; idim++) tt->SetBranchAddress(Form("x%d", idim), &pntTrain[idim]);
90                 t = tt;
91         }
92         gROOT->cd();
93         
94         // do interpolation
95         printf("do interpolation ... \n");
96         Double_t pdf, pdf_estimate, chi2;
97         TString vl = "x0";
98         for(int idim=1; idim<ndim; idim++) vl+=Form(":x%d", idim);
99         TKDInterpolator in(t, vl.Data(), "", 200.);
100         chi2 = 0.;
101 /*      for(int ip=0; ip<tEval->GetEntries(); ip++){
102                 tEval->GetEntry(ip);
103                 printf("\nEval %d\n", ip);*/
104         TH1 *h1 = new TH2F("h1", "", 50, 0., 100., 50, 0., 100.);
105         TH1 *h2 = new TH2F("h2", "", 50, 0., 100., 50, 0., 100.);
106         TAxis *ax = h2->GetXaxis(), *ay = h2->GetYaxis();
107         for(int ix=2; ix<ax->GetNbins(); ix++){
108                 pntTest[0] = ax->GetBinCenter(ix);
109         for(int iy=2; iy<ay->GetNbins(); iy++){
110                 pntTest[1] = ay->GetBinCenter(iy);
111                 memcpy(pntTrain, pntTest, ndim*sizeof(Double_t));
112
113                 if(kTransform)
114                         for(int idim=0; idim<ndim; idim++)
115                                 if(pntTrain[idim] > 0.) pntTrain[idim] = TMath::Log(pntTrain[idim]);
116                                 else pntTrain[idim] = 0.;
117                 
118                 if(kPCA){
119                         princ.X2P(pntTrain, pntRotate);
120                         memcpy(pntTrain, pntRotate, ndim*sizeof(Double_t));
121                 }
122
123                 pdf_estimate = in.Eval(pntTrain, 30);
124                 // calculate chi2
125                 if(kTransform)
126                         for(int idim=0; idim<ndim; idim++)
127                                 if(pntTest[idim] > 0.) pdf_estimate /= pntTest[idim];
128                                 else continue; 
129                 
130                 h1->SetBinContent(ix, iy, pdf_estimate);
131                 
132                 pdf = 1.; for(int idim=0; idim<ndim; idim++) pdf *= TMath::Landau(pntTest[idim], 5.);
133                 h2->SetBinContent(ix, iy, pdf);
134                 pdf_estimate -= pdf;
135                 chi2 += pdf_estimate*pdf_estimate/pdf;
136         }}
137         f->Close(); delete f;
138         fEval->Close(); delete fEval;
139         
140         // results presentation
141         printf("chi2 = %f\n", chi2);
142         TCanvas *c = 0x0;
143         if(!(c = (TCanvas*)gROOT->FindObject("c"))){
144                 c = new TCanvas("c", "", 10, 10, 900, 500);
145                 c->Divide(2, 1);
146         }
147         c->cd(1);
148         h1->Draw("lego2"); h1->GetZaxis()->SetRangeUser(1.e-9, 5.e-2); gPad->SetLogz(); gPad->Modified(); gPad->Update();
149         
150         c->cd(2);
151         h2->Draw("lego2"); h2->GetZaxis()->SetRangeUser(1.e-9, 5.e-2); gPad->SetLogz(); gPad->Modified(); gPad->Update();
152         return chi2;
153 }
154