summaryrefslogtreecommitdiff
path: root/weighted_dist.go
diff options
context:
space:
mode:
author Yawning Angel <yawning@schwanenlied.me> 2014-06-19 06:29:12 +0000
committer Yawning Angel <yawning@schwanenlied.me> 2014-06-19 06:29:12 +0000
commit 5abad1571c7d0869e29d55ca01df83fef8cd4606 (patch)
tree f6bb184672a435f98b5c7c9713aaacb0d37c5ce0 /weighted_dist.go
parent 6245391c93adf7b9617888d9d9ca8b12518cc52a (diff)
Use Vose's Alias Method to sample the weighted distribution.
The weight generation code was also cleaned up (and, as a compile-time change, can now generate distributions that look like what ScrambleSuit does). Algorithm per: http://www.keithschwarz.com/darts-dice-coins/
Diffstat (limited to 'weighted_dist.go')
-rw-r--r-- weighted_dist.go 180
1 files changed, 140 insertions, 40 deletions
diff --git a/weighted_dist.go b/weighted_dist.go
index 02fb26d..7c47cb8 100644
--- a/weighted_dist.go
+++ b/weighted_dist.go
@@ -28,6 +28,7 @@
package obfs4
import (
+ "container/list"
"fmt"
"math/rand"
@@ -36,27 +37,25 @@ import (
)
const (
- minBuckets = 1
- maxBuckets = 100
+ minValues = 1
+ maxValues = 100
)
// wDist is a weighted distribution.
type wDist struct {
- minValue int
- maxValue int
- values []int
- buckets []int64
- totalWeight int64
+ minValue int
+ maxValue int
+ values []int
+ weights []float64
- rng *rand.Rand
+ alias []int
+ prob []float64
}
// newWDist creates a weighted distribution of values ranging from min to max
// based on a HashDrbg initialized with seed.
func newWDist(seed *drbg.Seed, min, max int) (w *wDist) {
- w = new(wDist)
- w.minValue = min
- w.maxValue = max
+ w = &wDist{minValue: min, maxValue: max}
if max <= min {
panic(fmt.Sprintf("wDist.Reset(): min >= max (%d, %d)", min, max))
@@ -67,46 +66,147 @@ func newWDist(seed *drbg.Seed, min, max int) (w *wDist) {
return
}
-// sample generates a random value according to the distribution.
-func (w *wDist) sample() int {
- retIdx := 0
- var totalWeight int64
- weight := csrand.Int63n(w.totalWeight)
- for i, bucketWeight := range w.buckets {
- totalWeight += bucketWeight
- if weight <= totalWeight {
- retIdx = i
- break
+// genValues creates a slice containing a random number of distinct random
+// values that, when offset by adding minValue, fall into [minValue, maxValue].
+func (w *wDist) genValues(rng *rand.Rand) {
+ nValues := (w.maxValue + 1) - w.minValue
+ values := rng.Perm(nValues)
+ if nValues < minValues {
+ nValues = minValues
+ }
+ if nValues > maxValues {
+ nValues = maxValues
+ }
+ nValues = rng.Intn(nValues) + 1
+ w.values = values[:nValues]
+}
+
+// genBiasedWeights generates a non-uniform weight list, similar to the
+// ScrambleSuit prob_dist module.
+func (w *wDist) genBiasedWeights(rng *rand.Rand) {
+ w.weights = make([]float64, len(w.values))
+
+ culmProb := 0.0
+ for i := range w.values {
+ p := (1.0 - culmProb) * rng.Float64()
+ w.weights[i] = p
+ culmProb += p
+ }
+}
+
+// genUniformWeights generates a weight list with each weight drawn uniformly at random from [0, 1).
+func (w *wDist) genUniformWeights(rng *rand.Rand) {
+ w.weights = make([]float64, len(w.values))
+ for i := range w.weights {
+ w.weights[i] = rng.Float64()
+ }
+}
+
+// genTables calculates the alias and prob tables used for Vose's Alias method.
+// Algorithm taken from http://www.keithschwarz.com/darts-dice-coins/
+func (w *wDist) genTables() {
+ n := len(w.weights)
+ var sum float64
+ for _, weight := range w.weights {
+ sum += weight
+ }
+
+ // Create arrays $Alias$ and $Prob$, each of size $n$.
+ alias := make([]int, n)
+ prob := make([]float64, n)
+
+ // Create two worklists, $Small$ and $Large$.
+ small := list.New()
+ large := list.New()
+
+ scaled := make([]float64, n)
+ for i, weight := range w.weights {
+ // Multiply each probability by $n$.
+ p_i := weight * float64(n) / sum
+ scaled[i] = p_i
+
+ // For each scaled probability $p_i$:
+ if scaled[i] < 1.0 {
+ // If $p_i < 1$, add $i$ to $Small$.
+ small.PushBack(i)
+ } else {
+ // Otherwise ($p_i \ge 1$), add $i$ to $Large$.
+ large.PushBack(i)
}
}
- return w.minValue + w.values[retIdx]
+ // While $Small$ and $Large$ are not empty: ($Large$ might be emptied first)
+ for small.Len() > 0 && large.Len() > 0 {
+ // Remove the first element from $Small$; call it $l$.
+ l := small.Remove(small.Front()).(int)
+ // Remove the first element from $Large$; call it $g$.
+ g := large.Remove(large.Front()).(int)
+
+ // Set $Prob[l] = p_l$.
+ prob[l] = scaled[l]
+ // Set $Alias[l] = g$.
+ alias[l] = g
+
+ // Set $p_g := (p_g + p_l) - 1$. (This is a more numerically stable option.)
+ scaled[g] = (scaled[g] + scaled[l]) - 1.0
+
+ if scaled[g] < 1.0 {
+ // If $p_g < 1$, add $g$ to $Small$.
+ small.PushBack(g)
+ } else {
+ // Otherwise ($p_g \ge 1$), add $g$ to $Large$.
+ large.PushBack(g)
+ }
+ }
+
+ // While $Large$ is not empty:
+ for large.Len() > 0 {
+ // Remove the first element from $Large$; call it $g$.
+ g := large.Remove(large.Front()).(int)
+ // Set $Prob[g] = 1$.
+ prob[g] = 1.0
+ }
+
+ // While $Small$ is not empty: This is only possible due to numerical instability.
+ for small.Len() > 0 {
+ // Remove the first element from $Small$; call it $l$.
+ l := small.Remove(small.Front()).(int)
+ // Set $Prob[l] = 1$.
+ prob[l] = 1.0
+ }
+
+ w.prob = prob
+ w.alias = alias
}
// reset generates a new distribution with the same min/max based on a new seed.
func (w *wDist) reset(seed *drbg.Seed) {
// Initialize the deterministic random number generator.
drbg := drbg.NewHashDrbg(seed)
- w.rng = rand.New(drbg)
+ rng := rand.New(drbg)
- nBuckets := (w.maxValue + 1) - w.minValue
- w.values = w.rng.Perm(nBuckets)
- if nBuckets < minBuckets {
- nBuckets = minBuckets
- }
- if nBuckets > maxBuckets {
- nBuckets = maxBuckets
- }
- nBuckets = w.rng.Intn(nBuckets) + 1
-
- w.totalWeight = 0
- w.buckets = make([]int64, nBuckets)
- for i, _ := range w.buckets {
- prob := w.rng.Int63n(1000)
- w.buckets[i] = prob
- w.totalWeight += prob
+ w.genValues(rng)
+ //w.genBiasedWeights(rng)
+ w.genUniformWeights(rng)
+ w.genTables()
+}
+
+// sample generates a random value according to the distribution.
+func (w *wDist) sample() int {
+ var idx int
+
+ // Generate a fair die roll from an $n$-sided die; call the side $i$.
+ i := csrand.Intn(len(w.values))
+ // Flip a biased coin that comes up heads with probability $Prob[i]$.
+ if csrand.Float64() <= w.prob[i] {
+ // If the coin comes up "heads," return $i$.
+ idx = i
+ } else {
+ // Otherwise, return $Alias[i]$.
+ idx = w.alias[i]
}
- w.buckets[len(w.buckets)-1] = w.totalWeight
+
+ return w.minValue + w.values[idx]
}
/* vim :set ts=4 sw=4 sts=4 noet : */