]> git.uio.no Git - u/mrichter/AliRoot.git/blob - STAT/TKDTree.h
Follow the compilation scheme of AliRoot and to fulfill the C++ effic
[u/mrichter/AliRoot.git] / STAT / TKDTree.h
1 #ifndef ROOT_TKDTree
2 #define ROOT_TKDTree
3
4 #ifndef ROOT_TObject
5 #include "TObject.h"
6 #endif
7
8 #include "TMath.h"
9 template <typename Index, typename Value> class TKDTree : public TObject
10 {
11 public:
12         enum{
13                 kDimMax = 6
14         };
15
16         TKDTree();
17         TKDTree(Index npoints, Index ndim, UInt_t bsize, Value **data);
18         ~TKDTree();
19         
20         // getters
21                                         inline  Index*  GetPointsIndexes(Int_t node) const {
22                                                 if(node < fNnodes) return 0x0;
23                                                 Int_t offset = (node >= fCrossNode) ? (node-fCrossNode)*fBucketSize : fOffset+(node-fNnodes)*fBucketSize;
24                                                 return &fIndPoints[offset];
25                                         }
26                                         inline Char_t           GetNodeAxis(Int_t id) const {return (id < 0 || id >= fNnodes) ? 0 : fNodes[id].fAxis;}
27                                         inline Value            GetNodeValue(Int_t id) const {return (id < 0 || id >= fNnodes) ? 0 : fNodes[id].fValue;}
28                                         inline Int_t            GetNNodes() const {return fNnodes;}
29                                         inline Int_t            GetNTerminalNodes() const {return fNpoints/fBucketSize + ((fNpoints%fBucketSize)?1:0);}
30                                         inline Value*           GetBoundaries() const {return fBoundaries;}
31                                         inline Value*           GetBoundary(const Int_t node) const {return &fBoundaries[node*2*fNDim];}
32         static                                  Int_t           GetIndex(Int_t row, Int_t collumn){return collumn+(1<<row);}
33         static  inline  void            GetCoord(Int_t index, Int_t &row, Int_t &collumn){for (row=0; index>=(16<<row);row+=4); for (; index>=(2<<row);row++);collumn= index-(1<<row);};
34                                                                         Bool_t  FindNearestNeighbors(const Value *point, const Int_t kNN, Index *&i, Value &d);
35                                                                         Index           FindNode(const Value * point);
36                                                                         void            FindPoint(Value * point, Index &index, Int_t &iter);
37                                                                         void            FindInRangeA(Value * point, Value * delta, Index *res , Index &npoints,Index & iter, Int_t bnode);
38                                                                         void            FindInRangeB(Value * point, Value * delta, Index *res , Index &npoints,Index & iter, Int_t bnode);
39                                         inline  void            FindBNodeA(Value * point, Value * delta, Int_t &inode);
40         //
41                                         inline  Bool_t  IsTerminal(Index inode){return (inode>=fNnodes);}
42         //
43                                                                         Value           KOrdStat(Index ntotal, Value *a, Index k, Index *index);
44                                                                         void            MakeBoundaries(Value *range = 0x0);
45                                                                         
46                                                                         void            SetData(Index npoints, Index ndim, UInt_t bsize, Value **data);
47         //
48                                                                         void            Spread(Index ntotal, Value *a, Index *index, Value &min, Value &max);
49
50 protected:
51                                                                         void            Build();  // build tree
52                                                                         
53 private:
54                                                                         TKDTree(const TKDTree &); // not implemented
55                                                                         TKDTree<Index, Value>& operator=(const TKDTree<Index, Value>&); // not implemented
56                                                                         void            CookBoundariesTerminal(Int_t parent_node, Bool_t left);
57
58 public:
59         struct TKDNode{
60                 Char_t fAxis;
61                 Value  fValue;
62         };
63
64 protected:
65         Bool_t  kDataOwner;  // Toggle ownership of the data
66         Int_t   fNnodes;     // size of node array
67         Index   fNDim;       // number of variables
68         Index   fNpoints;    // number of multidimensional points
69         Index   fBucketSize; // limit statistic for nodes 
70         Value   **fData;     //!
71         Value           *fRange;     //! range of data for each dimension     
72         Value           *fBoundaries;//! nodes boundaries - check class doc
73         TKDNode *fNodes;
74         Index   *fkNN;       //! k nearest neighbors
75
76 private:
77         Index   *fIndPoints; //! array of points indexes
78         Int_t   fRowT0;      // smallest terminal row
79         Int_t   fCrossNode;  // cross node
80         Int_t   fOffset;     // offset in fIndPoints
81
82         ClassDef(TKDTree, 1)  // KD tree
83 };
84
85
86 typedef TKDTree<Int_t, Double_t> TKDTreeID;
87 typedef TKDTree<Int_t, Float_t> TKDTreeIF;
88
89 //_________________________________________________________________
90 template <typename  Index, typename Value> void TKDTree<Index, Value>::FindBNodeA(Value *point, Value *delta, Int_t &inode){
91   //
92   // find the smallest node covering the full range - start
93   //
94   inode =0; 
95   for (;inode<fNnodes;){
96     TKDNode & node = fNodes[inode];
97     if (TMath::Abs(point[node.fAxis] - node.fValue)<delta[node.fAxis]) break;
98     inode = (point[node.fAxis] < node.fValue) ? (inode*2)+1: (inode*2)+2;    
99   }
100 }
101
102 #endif
103