author	Yawning Angel <yawning@schwanenlied.me>	2014-06-19 06:29:12 +0000
committer	Yawning Angel <yawning@schwanenlied.me>	2014-06-19 06:29:12 +0000
commit	5abad1571c7d0869e29d55ca01df83fef8cd4606 (patch)
tree	f6bb184672a435f98b5c7c9713aaacb0d37c5ce0
parent	6245391c93adf7b9617888d9d9ca8b12518cc52a (diff)
Use Vose's Alias Method to sample the weighted distribution.
The weight generation code was also cleaned up (and can now, via a compile-time change, generate distributions that look like what ScrambleSuit produces). Per: http://www.keithschwarz.com/darts-dice-coins/
-rw-r--r--	csrand/csrand.go	6
-rw-r--r--	weighted_dist.go	180
-rw-r--r--	weighted_dist_test.go	82
3 files changed, 225 insertions, 43 deletions
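
For reference, below is a minimal standalone sketch of the table construction and O(1) sampling described in the write-up linked above. It mirrors the structure of the patch but is not part of it; the names aliasTable and newAliasTable are illustrative only, and math/rand is used in place of the DRBG-backed generators.

package main

import (
	"fmt"
	"math/rand"
)

// aliasTable holds the prob/alias tables for Vose's Alias Method.
// Construction is O(n); each sample afterwards is O(1).
type aliasTable struct {
	prob  []float64
	alias []int
}

// newAliasTable builds the tables from a list of (unnormalized) weights.
func newAliasTable(weights []float64) *aliasTable {
	n := len(weights)
	var sum float64
	for _, w := range weights {
		sum += w
	}

	prob := make([]float64, n)
	alias := make([]int, n)

	// Scale each weight so the average is exactly 1, then split the
	// indices into those below and at-or-above the average.
	scaled := make([]float64, n)
	small := make([]int, 0, n)
	large := make([]int, 0, n)
	for i, w := range weights {
		scaled[i] = w * float64(n) / sum
		if scaled[i] < 1.0 {
			small = append(small, i)
		} else {
			large = append(large, i)
		}
	}

	// Pair each "small" column with a "large" donor that tops it up to 1.
	for len(small) > 0 && len(large) > 0 {
		l := small[len(small)-1]
		small = small[:len(small)-1]
		g := large[len(large)-1]
		large = large[:len(large)-1]

		prob[l] = scaled[l]
		alias[l] = g
		scaled[g] = (scaled[g] + scaled[l]) - 1.0
		if scaled[g] < 1.0 {
			small = append(small, g)
		} else {
			large = append(large, g)
		}
	}

	// Anything left over is (up to rounding error) exactly 1.
	for _, g := range large {
		prob[g] = 1.0
	}
	for _, l := range small {
		prob[l] = 1.0
	}

	return &aliasTable{prob: prob, alias: alias}
}

// sample draws one index with probability proportional to its weight.
func (t *aliasTable) sample(rng *rand.Rand) int {
	i := rng.Intn(len(t.prob)) // fair n-sided die roll
	if rng.Float64() <= t.prob[i] { // biased coin flip
		return i
	}
	return t.alias[i]
}

func main() {
	rng := rand.New(rand.NewSource(1))
	t := newAliasTable([]float64{0.5, 0.3, 0.2})
	hist := make([]int, 3)
	for i := 0; i < 100000; i++ {
		hist[t.sample(rng)]++
	}
	fmt.Println(hist) // roughly 50000/30000/20000
}
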
diff --git a/csrand/csrand.go b/csrand/csrand.go
index a3299aa..b059ed0 100644
--- a/csrand/csrand.go
+++ b/csrand/csrand.go
@@ -68,9 +68,9 @@ func (r csRandSource) Seed(seed int64) {
// No-op.
}
-// Int63n returns, as a int64, a pseudo random number in [0, n).
-func Int63n(n int64) int64 {
- return CsRand.Int63n(n)
+// Intn returns, as an int, a pseudo random number in [0, n).
+func Intn(n int) int {
+ return CsRand.Intn(n)
}
// Float64 returns, as a float64, a pseudo random number in [0.0,1.0).
diff --git a/weighted_dist.go b/weighted_dist.go
index 02fb26d..7c47cb8 100644
--- a/weighted_dist.go
+++ b/weighted_dist.go
@@ -28,6 +28,7 @@
package obfs4
import (
+ "container/list"
"fmt"
"math/rand"
@@ -36,27 +37,25 @@ import (
)
const (
- minBuckets = 1
- maxBuckets = 100
+ minValues = 1
+ maxValues = 100
)
// wDist is a weighted distribution.
type wDist struct {
- minValue int
- maxValue int
- values []int
- buckets []int64
- totalWeight int64
+ minValue int
+ maxValue int
+ values []int
+ weights []float64
- rng *rand.Rand
+ alias []int
+ prob []float64
}
// newWDist creates a weighted distribution of values ranging from min to max
// based on a HashDrbg initialized with seed.
func newWDist(seed *drbg.Seed, min, max int) (w *wDist) {
- w = new(wDist)
- w.minValue = min
- w.maxValue = max
+ w = &wDist{minValue: min, maxValue: max}
if max <= min {
panic(fmt.Sprintf("wDist.Reset(): min >= max (%d, %d)", min, max))
@@ -67,46 +66,147 @@ func newWDist(seed *drbg.Seed, min, max int) (w *wDist) {
return
}
-// sample generates a random value according to the distribution.
-func (w *wDist) sample() int {
- retIdx := 0
- var totalWeight int64
- weight := csrand.Int63n(w.totalWeight)
- for i, bucketWeight := range w.buckets {
- totalWeight += bucketWeight
- if weight <= totalWeight {
- retIdx = i
- break
+// genValues creates a slice containing a random number of random values
+// that when scaled by adding minValue will fall into [min, max].
+func (w *wDist) genValues(rng *rand.Rand) {
+ nValues := (w.maxValue + 1) - w.minValue
+ values := rng.Perm(nValues)
+ if nValues < minValues {
+ nValues = minValues
+ }
+ if nValues > maxValues {
+ nValues = maxValues
+ }
+ nValues = rng.Intn(nValues) + 1
+ w.values = values[:nValues]
+}
+
+// genBiasedWeights generates a non-uniform weight list, similar to the
+// ScrambleSuit prob_dist module.
+func (w *wDist) genBiasedWeights(rng *rand.Rand) {
+ w.weights = make([]float64, len(w.values))
+
+ culmProb := 0.0
+ for i := range w.values {
+ p := (1.0 - culmProb) * rng.Float64()
+ w.weights[i] = p
+ culmProb += p
+ }
+}
+
+// genUniformWeights generates a uniform weight list.
+func (w *wDist) genUniformWeights(rng *rand.Rand) {
+ w.weights = make([]float64, len(w.values))
+ for i := range w.weights {
+ w.weights[i] = rng.Float64()
+ }
+}
+
+// genTables calculates the alias and prob tables used for Vose's Alias method.
+// Algorithm taken from http://www.keithschwarz.com/darts-dice-coins/
+func (w *wDist) genTables() {
+ n := len(w.weights)
+ var sum float64
+ for _, weight := range w.weights {
+ sum += weight
+ }
+
+ // Create arrays $Alias$ and $Prob$, each of size $n$.
+ alias := make([]int, n)
+ prob := make([]float64, n)
+
+ // Create two worklists, $Small$ and $Large$.
+ small := list.New()
+ large := list.New()
+
+ scaled := make([]float64, n)
+ for i, weight := range w.weights {
+ // Multiply each probability by $n$.
+ p_i := weight * float64(n) / sum
+ scaled[i] = p_i
+
+ // For each scaled probability $p_i$:
+ if scaled[i] < 1.0 {
+ // If $p_i < 1$, add $i$ to $Small$.
+ small.PushBack(i)
+ } else {
+ // Otherwise ($p_i \ge 1$), add $i$ to $Large$.
+ large.PushBack(i)
}
}
- return w.minValue + w.values[retIdx]
+ // While $Small$ and $Large$ are not empty: ($Large$ might be emptied first)
+ for small.Len() > 0 && large.Len() > 0 {
+ // Remove the first element from $Small$; call it $l$.
+ l := small.Remove(small.Front()).(int)
+ // Remove the first element from $Large$; call it $g$.
+ g := large.Remove(large.Front()).(int)
+
+ // Set $Prob[l] = p_l$.
+ prob[l] = scaled[l]
+ // Set $Alias[l] = g$.
+ alias[l] = g
+
+ // Set $p_g := (p_g + p_l) - 1$. (This is a more numerically stable option.)
+ scaled[g] = (scaled[g] + scaled[l]) - 1.0
+
+ if scaled[g] < 1.0 {
+ // If $p_g < 1$, add $g$ to $Small$.
+ small.PushBack(g)
+ } else {
+ // Otherwise ($p_g \ge 1$), add $g$ to $Large$.
+ large.PushBack(g)
+ }
+ }
+
+ // While $Large$ is not empty:
+ for large.Len() > 0 {
+ // Remove the first element from $Large$; call it $g$.
+ g := large.Remove(large.Front()).(int)
+ // Set $Prob[g] = 1$.
+ prob[g] = 1.0
+ }
+
+ // While $Small$ is not empty: This is only possible due to numerical instability.
+ for small.Len() > 0 {
+ // Remove the first element from $Small$; call it $l$.
+ l := small.Remove(small.Front()).(int)
+ // Set $Prob[l] = 1$.
+ prob[l] = 1.0
+ }
+
+ w.prob = prob
+ w.alias = alias
}
// reset generates a new distribution with the same min/max based on a new seed.
func (w *wDist) reset(seed *drbg.Seed) {
// Initialize the deterministic random number generator.
drbg := drbg.NewHashDrbg(seed)
- w.rng = rand.New(drbg)
+ rng := rand.New(drbg)
- nBuckets := (w.maxValue + 1) - w.minValue
- w.values = w.rng.Perm(nBuckets)
- if nBuckets < minBuckets {
- nBuckets = minBuckets
- }
- if nBuckets > maxBuckets {
- nBuckets = maxBuckets
- }
- nBuckets = w.rng.Intn(nBuckets) + 1
-
- w.totalWeight = 0
- w.buckets = make([]int64, nBuckets)
- for i, _ := range w.buckets {
- prob := w.rng.Int63n(1000)
- w.buckets[i] = prob
- w.totalWeight += prob
+ w.genValues(rng)
+ //w.genBiasedWeights(rng)
+ w.genUniformWeights(rng)
+ w.genTables()
+}
+
+// sample generates a random value according to the distribution.
+func (w *wDist) sample() int {
+ var idx int
+
+ // Generate a fair die roll from an $n$-sided die; call the side $i$.
+ i := csrand.Intn(len(w.values))
+ // Flip a biased coin that comes up heads with probability $Prob[i]$.
+ if csrand.Float64() <= w.prob[i] {
+ // If the coin comes up "heads," return $i$.
+ idx = i
+ } else {
+ // Otherwise, return $Alias[i]$.
+ idx = w.alias[i]
}
- w.buckets[len(w.buckets)-1] = w.totalWeight
+
+ return w.minValue + w.values[idx]
}
/* vim :set ts=4 sw=4 sts=4 noet : */
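
As an aside, the (currently compiled-out) genBiasedWeights path in the diff above is a simple "stick-breaking" recurrence: each weight claims a random share of whatever probability mass is still unclaimed, so earlier entries of the permuted value list tend to be heavier, much like ScrambleSuit's prob_dist module. A rough standalone sketch of that recurrence, with the illustrative helper name biasedWeights and math/rand standing in for the DRBG-backed generator:

package main

import (
	"fmt"
	"math/rand"
)

// biasedWeights sketches the stick-breaking recurrence used by
// genBiasedWeights: each weight takes a random cut of the probability
// mass that is still unclaimed, so earlier entries tend to be heavier.
func biasedWeights(rng *rand.Rand, n int) []float64 {
	weights := make([]float64, n)
	remaining := 1.0
	for i := range weights {
		p := remaining * rng.Float64() // random cut of what is left
		weights[i] = p
		remaining -= p
	}
	return weights
}

func main() {
	rng := rand.New(rand.NewSource(1))
	fmt.Println(biasedWeights(rng, 5)) // decreasing on average
}
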
diff --git a/weighted_dist_test.go b/weighted_dist_test.go
new file mode 100644
index 0000000..14fecec
--- /dev/null
+++ b/weighted_dist_test.go
@@ -0,0 +1,82 @@
+/*
+ * Copyright (c) 2014, Yawning Angel <yawning at torproject dot org>
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * * Redistributions of source code must retain the above copyright notice,
+ * this list of conditions and the following disclaimer.
+ *
+ * * Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ */
+
+package obfs4
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/yawning/obfs4/drbg"
+)
+
+const debug = false
+
+func TestWeightedDist(t *testing.T) {
+ seed, err := drbg.NewSeed()
+ if err != nil {
+ t.Fatal("failed to generate a DRBG seed:", err)
+ }
+
+ const nrTrials = 1000000
+
+ hist := make([]int, 1000)
+
+ w := newWDist(seed, 0, 999)
+ if debug {
+ // Dump a string representation of the probability table.
+ fmt.Println("Table:")
+ var sum float64
+ for _, weight := range w.weights {
+ sum += weight
+ }
+ for i, weight := range w.weights {
+ p := weight / sum
+ if p > 0.000001 { // Filter out tiny values.
+ fmt.Printf(" [%d]: %f\n", w.minValue+w.values[i], p)
+ }
+ }
+ fmt.Println()
+ }
+
+ for i := 0; i < nrTrials; i++ {
+ value := w.sample()
+ hist[value]++
+ }
+
+ if debug {
+ fmt.Println("Generated:")
+ for value, count := range hist {
+ if count != 0 {
+ p := float64(count) / float64(nrTrials)
+ fmt.Printf(" [%d]: %f (%d)\n", value, p, count)
+ }
+ }
+ }
+}
+
+/* vim :set ts=4 sw=4 sts=4 noet : */