summaryrefslogtreecommitdiff
path: root/vendor/github.com/hongshibao/go-kdtree/kdtree.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/hongshibao/go-kdtree/kdtree.go')
-rw-r--r--vendor/github.com/hongshibao/go-kdtree/kdtree.go200
1 files changed, 200 insertions, 0 deletions
diff --git a/vendor/github.com/hongshibao/go-kdtree/kdtree.go b/vendor/github.com/hongshibao/go-kdtree/kdtree.go
new file mode 100644
index 0000000..64ecd60
--- /dev/null
+++ b/vendor/github.com/hongshibao/go-kdtree/kdtree.go
@@ -0,0 +1,200 @@
+package kdtree
+
+import (
+ "container/heap"
+
+ "github.com/hongshibao/go-algo"
+)
+
+type Point interface {
+ // Return the total number of dimensions
+ Dim() int
+ // Return the value X_{dim}, dim is started from 0
+ GetValue(dim int) float64
+ // Return the distance between two points
+ Distance(p Point) float64
+ // Return the distance between the point and the plane X_{dim}=val
+ PlaneDistance(val float64, dim int) float64
+}
+
+type PointBase struct {
+ Point
+ Vec []float64
+}
+
+func (b PointBase) Dim() int {
+ return len(b.Vec)
+}
+
+func (b PointBase) GetValue(dim int) float64 {
+ return b.Vec[dim]
+}
+
+func NewPointBase(vals []float64) PointBase {
+ ret := PointBase{}
+ for _, val := range vals {
+ ret.Vec = append(ret.Vec, val)
+ }
+ return ret
+}
+
+type kdTreeNode struct {
+ axis int
+ splittingPoint Point
+ leftChild *kdTreeNode
+ rightChild *kdTreeNode
+}
+
+type KDTree struct {
+ root *kdTreeNode
+ dim int
+}
+
+func (t *KDTree) Dim() int {
+ return t.dim
+}
+
+func (t *KDTree) KNN(target Point, k int) []Point {
+ hp := &kNNHeapHelper{}
+ t.search(t.root, hp, target, k)
+ ret := make([]Point, 0, hp.Len())
+ for hp.Len() > 0 {
+ item := heap.Pop(hp).(*kNNHeapNode)
+ ret = append(ret, item.point)
+ }
+ for i := len(ret)/2 - 1; i >= 0; i-- {
+ opp := len(ret) - 1 - i
+ ret[i], ret[opp] = ret[opp], ret[i]
+ }
+ return ret
+}
+
+func (t *KDTree) search(p *kdTreeNode,
+ hp *kNNHeapHelper, target Point, k int) {
+ stk := make([]*kdTreeNode, 0)
+ for p != nil {
+ stk = append(stk, p)
+ if target.GetValue(p.axis) < p.splittingPoint.GetValue(p.axis) {
+ p = p.leftChild
+ } else {
+ p = p.rightChild
+ }
+ }
+ for i := len(stk) - 1; i >= 0; i-- {
+ cur := stk[i]
+ dist := target.Distance(cur.splittingPoint)
+ if hp.Len() < k || (*hp)[0].distance >= dist {
+ heap.Push(hp, &kNNHeapNode{
+ point: cur.splittingPoint,
+ distance: dist,
+ })
+ if hp.Len() > k {
+ heap.Pop(hp)
+ }
+ }
+ if hp.Len() < k || target.PlaneDistance(
+ cur.splittingPoint.GetValue(cur.axis), cur.axis) <=
+ (*hp)[0].distance {
+ if target.GetValue(cur.axis) < cur.splittingPoint.GetValue(cur.axis) {
+ t.search(cur.rightChild, hp, target, k)
+ } else {
+ t.search(cur.leftChild, hp, target, k)
+ }
+ }
+ }
+}
+
+func NewKDTree(points []Point) *KDTree {
+ if len(points) == 0 {
+ return nil
+ }
+ ret := &KDTree{
+ dim: points[0].Dim(),
+ root: createKDTree(points, 0),
+ }
+ return ret
+}
+
+func createKDTree(points []Point, depth int) *kdTreeNode {
+ if len(points) == 0 {
+ return nil
+ }
+ dim := points[0].Dim()
+ ret := &kdTreeNode{
+ axis: depth % dim,
+ }
+ if len(points) == 1 {
+ ret.splittingPoint = points[0]
+ return ret
+ }
+ idx := selectSplittingPoint(points, ret.axis)
+ if idx == -1 {
+ return nil
+ }
+ ret.splittingPoint = points[idx]
+ ret.leftChild = createKDTree(points[0:idx], depth+1)
+ ret.rightChild = createKDTree(points[idx+1:len(points)], depth+1)
+ return ret
+}
+
+type selectionHelper struct {
+ axis int
+ points []Point
+}
+
+func (h *selectionHelper) Len() int {
+ return len(h.points)
+}
+
+func (h *selectionHelper) Less(i, j int) bool {
+ return h.points[i].GetValue(h.axis) < h.points[j].GetValue(h.axis)
+}
+
+func (h *selectionHelper) Swap(i, j int) {
+ h.points[i], h.points[j] = h.points[j], h.points[i]
+}
+
+func selectSplittingPoint(points []Point, axis int) int {
+ helper := &selectionHelper{
+ axis: axis,
+ points: points,
+ }
+ mid := len(points)/2 + 1
+ err := algo.QuickSelect(helper, mid)
+ if err != nil {
+ return -1
+ }
+ return mid - 1
+}
+
+type kNNHeapNode struct {
+ point Point
+ distance float64
+}
+
+type kNNHeapHelper []*kNNHeapNode
+
+func (h kNNHeapHelper) Len() int {
+ return len(h)
+}
+
+func (h kNNHeapHelper) Less(i, j int) bool {
+ return h[i].distance > h[j].distance
+}
+
+func (h kNNHeapHelper) Swap(i, j int) {
+ h[i], h[j] = h[j], h[i]
+}
+
+func (h *kNNHeapHelper) Push(x interface{}) {
+ item := x.(*kNNHeapNode)
+ *h = append(*h, item)
+}
+
+func (h *kNNHeapHelper) Pop() interface{} {
+ old := *h
+ n := len(old)
+ item := old[n-1]
+ *h = old[0 : n-1]
+ return item
+}