diff options
author | Kali Kaneko (leap communications) <kali@leap.se> | 2018-12-13 19:45:48 +0100 |
---|---|---|
committer | Kali Kaneko (leap communications) <kali@leap.se> | 2018-12-13 19:47:11 +0100 |
commit | 81843a0477c5eba4b3d4f6ff932f748f9e1244e7 (patch) | |
tree | 6068ae0ff1d8ac617e2781fb8aa68200e71403f2 /vendor/github.com/hongshibao/go-kdtree/kdtree.go | |
parent | ac56a90e21118dc1ddfc1bd5e69d87f157710412 (diff) |
vendor packages
Diffstat (limited to 'vendor/github.com/hongshibao/go-kdtree/kdtree.go')
-rw-r--r-- | vendor/github.com/hongshibao/go-kdtree/kdtree.go | 200 |
1 files changed, 200 insertions, 0 deletions
diff --git a/vendor/github.com/hongshibao/go-kdtree/kdtree.go b/vendor/github.com/hongshibao/go-kdtree/kdtree.go new file mode 100644 index 0000000..64ecd60 --- /dev/null +++ b/vendor/github.com/hongshibao/go-kdtree/kdtree.go @@ -0,0 +1,200 @@ +package kdtree + +import ( + "container/heap" + + "github.com/hongshibao/go-algo" +) + +type Point interface { + // Return the total number of dimensions + Dim() int + // Return the value X_{dim}, dim is started from 0 + GetValue(dim int) float64 + // Return the distance between two points + Distance(p Point) float64 + // Return the distance between the point and the plane X_{dim}=val + PlaneDistance(val float64, dim int) float64 +} + +type PointBase struct { + Point + Vec []float64 +} + +func (b PointBase) Dim() int { + return len(b.Vec) +} + +func (b PointBase) GetValue(dim int) float64 { + return b.Vec[dim] +} + +func NewPointBase(vals []float64) PointBase { + ret := PointBase{} + for _, val := range vals { + ret.Vec = append(ret.Vec, val) + } + return ret +} + +type kdTreeNode struct { + axis int + splittingPoint Point + leftChild *kdTreeNode + rightChild *kdTreeNode +} + +type KDTree struct { + root *kdTreeNode + dim int +} + +func (t *KDTree) Dim() int { + return t.dim +} + +func (t *KDTree) KNN(target Point, k int) []Point { + hp := &kNNHeapHelper{} + t.search(t.root, hp, target, k) + ret := make([]Point, 0, hp.Len()) + for hp.Len() > 0 { + item := heap.Pop(hp).(*kNNHeapNode) + ret = append(ret, item.point) + } + for i := len(ret)/2 - 1; i >= 0; i-- { + opp := len(ret) - 1 - i + ret[i], ret[opp] = ret[opp], ret[i] + } + return ret +} + +func (t *KDTree) search(p *kdTreeNode, + hp *kNNHeapHelper, target Point, k int) { + stk := make([]*kdTreeNode, 0) + for p != nil { + stk = append(stk, p) + if target.GetValue(p.axis) < p.splittingPoint.GetValue(p.axis) { + p = p.leftChild + } else { + p = p.rightChild + } + } + for i := len(stk) - 1; i >= 0; i-- { + cur := stk[i] + dist := target.Distance(cur.splittingPoint) + if hp.Len() < k || (*hp)[0].distance >= dist { + heap.Push(hp, &kNNHeapNode{ + point: cur.splittingPoint, + distance: dist, + }) + if hp.Len() > k { + heap.Pop(hp) + } + } + if hp.Len() < k || target.PlaneDistance( + cur.splittingPoint.GetValue(cur.axis), cur.axis) <= + (*hp)[0].distance { + if target.GetValue(cur.axis) < cur.splittingPoint.GetValue(cur.axis) { + t.search(cur.rightChild, hp, target, k) + } else { + t.search(cur.leftChild, hp, target, k) + } + } + } +} + +func NewKDTree(points []Point) *KDTree { + if len(points) == 0 { + return nil + } + ret := &KDTree{ + dim: points[0].Dim(), + root: createKDTree(points, 0), + } + return ret +} + +func createKDTree(points []Point, depth int) *kdTreeNode { + if len(points) == 0 { + return nil + } + dim := points[0].Dim() + ret := &kdTreeNode{ + axis: depth % dim, + } + if len(points) == 1 { + ret.splittingPoint = points[0] + return ret + } + idx := selectSplittingPoint(points, ret.axis) + if idx == -1 { + return nil + } + ret.splittingPoint = points[idx] + ret.leftChild = createKDTree(points[0:idx], depth+1) + ret.rightChild = createKDTree(points[idx+1:len(points)], depth+1) + return ret +} + +type selectionHelper struct { + axis int + points []Point +} + +func (h *selectionHelper) Len() int { + return len(h.points) +} + +func (h *selectionHelper) Less(i, j int) bool { + return h.points[i].GetValue(h.axis) < h.points[j].GetValue(h.axis) +} + +func (h *selectionHelper) Swap(i, j int) { + h.points[i], h.points[j] = h.points[j], h.points[i] +} + +func selectSplittingPoint(points []Point, axis int) int { + helper := &selectionHelper{ + axis: axis, + points: points, + } + mid := len(points)/2 + 1 + err := algo.QuickSelect(helper, mid) + if err != nil { + return -1 + } + return mid - 1 +} + +type kNNHeapNode struct { + point Point + distance float64 +} + +type kNNHeapHelper []*kNNHeapNode + +func (h kNNHeapHelper) Len() int { + return len(h) +} + +func (h kNNHeapHelper) Less(i, j int) bool { + return h[i].distance > h[j].distance +} + +func (h kNNHeapHelper) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *kNNHeapHelper) Push(x interface{}) { + item := x.(*kNNHeapNode) + *h = append(*h, item) +} + +func (h *kNNHeapHelper) Pop() interface{} { + old := *h + n := len(old) + item := old[n-1] + *h = old[0 : n-1] + return item +} |