summaryrefslogtreecommitdiff
path: root/pkg/web/middleware.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/web/middleware.go')
-rw-r--r--pkg/web/middleware.go89
1 files changed, 89 insertions, 0 deletions
diff --git a/pkg/web/middleware.go b/pkg/web/middleware.go
new file mode 100644
index 0000000..3a74477
--- /dev/null
+++ b/pkg/web/middleware.go
@@ -0,0 +1,89 @@
+// Copyright (C) 2019 LEAP
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package web
+
+import (
+ "0xacab.org/leap/vpnweb/pkg/auth/creds"
+ "0xacab.org/leap/vpnweb/pkg/config"
+ "encoding/json"
+ "github.com/auth0/go-jwt-middleware"
+ "github.com/dgrijalva/jwt-go"
+ "log"
+ "net/http"
+ "os"
+ "strings"
+ "time"
+)
+
+const debugAuth string = "VPNWEB_DEBUG_AUTH"
+
+func AuthMiddleware(authenticationFunc func(*creds.Credentials) bool, opts *config.Opts) http.HandlerFunc {
+ debugAuth, exists := os.LookupEnv(debugAuth)
+ if !exists {
+ debugAuth = "false"
+ }
+ var authHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ var c creds.Credentials
+ err := json.NewDecoder(r.Body).Decode(&c)
+ if err != nil {
+ log.Println("Auth request did not send valid json")
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+
+ if c.User == "" || c.Password == "" {
+ log.Println("Auth request did not include user or password")
+ http.Error(w, "Missing user and/or password", http.StatusBadRequest)
+ return
+ }
+
+ valid := authenticationFunc(&c)
+
+ if !valid {
+ log.Println("Wrong auth for user", c.User)
+ http.Error(w, "Wrong user and/or password", http.StatusUnauthorized)
+ return
+ }
+
+ if strings.ToLower(debugAuth) == "yes" {
+ log.Println("Valid auth for user", c.User)
+ }
+ token := jwt.New(jwt.SigningMethodHS256)
+ claims := token.Claims.(jwt.MapClaims)
+ claims["expiration"] = time.Now().Add(time.Hour * 24).Unix()
+ tokenString, _ := token.SignedString([]byte(opts.AuthSecret))
+ w.Write([]byte(tokenString))
+ })
+ return authHandler
+}
+
+func RestrictedMiddleware(shouldProtect func() bool, handler func(w http.ResponseWriter, r *http.Request), opts *config.Opts) http.Handler {
+
+ jwtMiddleware := jwtmiddleware.New(jwtmiddleware.Options{
+ ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
+ return []byte(opts.AuthSecret), nil
+ },
+ SigningMethod: jwt.SigningMethodHS256,
+ })
+
+ switch shouldProtect() {
+ case false:
+ return http.HandlerFunc(handler)
+ case true:
+ return jwtMiddleware.Handler(http.HandlerFunc(handler))
+ }
+ return nil
+}