From a0bc2768c04255e13ef87798d87e4916d7bf17fc Mon Sep 17 00:00:00 2001 From: "kali kaneko (leap communications)" Date: Mon, 24 Feb 2020 18:53:18 +0100 Subject: [tests] return auth errors and add unit tests for message parsing --- pkg/auth/sip2/client.go | 27 +++++++++++---- pkg/auth/sip2/spec.go | 84 ++++++++++++++++++++++++++++++++++++++-------- pkg/auth/sip2/spec_test.go | 66 ++++++++++++++++++++++++++++++++++++ pkg/web/middleware.go | 4 +-- 4 files changed, 159 insertions(+), 22 deletions(-) create mode 100644 pkg/auth/sip2/spec_test.go diff --git a/pkg/auth/sip2/client.go b/pkg/auth/sip2/client.go index 623dc12..567d908 100644 --- a/pkg/auth/sip2/client.go +++ b/pkg/auth/sip2/client.go @@ -213,17 +213,32 @@ func (c *sipClient) CheckCredentials(credentials *creds.Credentials) (bool, erro statusMsg, err = c.parseResponse(resp) if err != nil { - log.Println("Error while parsing response") + log.Println("Error while parsing response:", resp) return false, err } - if value, ok := c.parser.getFieldValue(statusMsg, validPatron); ok && value == yes { - if value, ok := c.parser.getFieldValue(statusMsg, validPatronPassword); ok && value == yes { + if valid, err := isValidUser(statusMsg); valid { + if valid, err := isValidPassword(statusMsg); valid { return true, nil + } else { + return false, err } + } else { + return false, err } +} - // TODO log whatever error we can find (AF, Screen Message, for instance) - log.Printf("AUTH ERROR. RESPONSE: %s\n", resp) +func isValidUser(m *message) (bool, error) { + value, ok := m.getFieldValue(validPatron) + if !ok { + return false, errors.New("parse error: expected BL field") + } + return toBool(value) +} - return false, errors.New("unknown error while checking credentials") +func isValidPassword(m *message) (bool, error) { + value, ok := m.getFieldValue(validPatronPassword) + if !ok { + return false, errors.New("parse error: expected CQ field") + } + return toBool(value) } diff --git a/pkg/auth/sip2/spec.go b/pkg/auth/sip2/spec.go index 80fbe7a..07db87f 100644 --- a/pkg/auth/sip2/spec.go +++ b/pkg/auth/sip2/spec.go @@ -24,19 +24,32 @@ import ( const ( yes string = "Y" + no string = "N" trueVal string = "1" + falseVal string = "0" okVal string = "ok" language string = "language" patronStatus string = "patron status" + onlineStatus string = "on-line status" + checkinOk string = "checkin ok" + checkoutOk string = "checkout ok" + renewalPolicy string = "acs renewal policy" + statusUpdate string = "status update ok" + offlineOK string = "offline ok" + timeoutPeriod string = "timeout period" + retriesAllowed string = "retries allowed" + dateTimeSync string = "date/time sync" date string = "transaction date" patronIdentifier string = "patron identifier" patronPassword string = "patron password" + protocolVersion string = "protocol version" personalName string = "personal name" screenMessage string = "screen message" institutionID string = "institution id" validPatron string = "valid patron" validPatronPassword string = "valid patron password" loginResponse string = "Login Response" + ascStatus string = "ASC Status" patronStatusResponse string = "Patron Status Response" ) @@ -66,12 +79,56 @@ type messageSpec struct { fields []fixedFieldSpec } +func toBool(val string) (bool, error) { + var ret bool + switch val { + case trueVal: + ret = true + case falseVal: + ret = false + case yes: + ret = true + case no: + ret = false + default: + return false, errors.New("cannot parse value") + } + return ret, nil +} + type message struct { fields []variableField fixedFields []fixedField msgTxt string } +func (m *message) getFieldValue(field string) (string, bool) { + for _, v := range m.fields { + if v.spec.label == field { + return v.value, true + } + } + return "", false +} + +func (m *message) getFixedFieldValue(field string) (string, bool) { + for _, v := range m.fixedFields { + if v.spec.label == field { + return v.value, true + } + } + return "", false +} + +func (m *message) getValueByCode(code string) (string, bool) { + for _, v := range m.fields { + if v.spec.id == code { + return v.value, true + } + } + return "", false +} + type Parser struct { msgByCodeMap map[int]messageSpec variableFieldByCodeMap map[string]variableFieldSpec @@ -84,9 +141,21 @@ func getParser() *Parser { dateSpec := fixedFieldSpec{18, date} okSpec := fixedFieldSpec{1, okVal} + onlineStatusSpec := fixedFieldSpec{1, onlineStatus} + checkinOkSpec := fixedFieldSpec{1, checkinOk} + checkoutOkSpec := fixedFieldSpec{1, checkoutOk} + renewalSpec := fixedFieldSpec{1, renewalPolicy} + stUpdateSpec := fixedFieldSpec{1, statusUpdate} + offlineOkSpec := fixedFieldSpec{1, offlineOK} + timeoutSpec := fixedFieldSpec{3, timeoutPeriod} + retriesSpec := fixedFieldSpec{3, retriesAllowed} + dateTimeSyncSpec := fixedFieldSpec{18, dateTimeSync} + protoSpec := fixedFieldSpec{4, protocolVersion} + msgByCodeMap := map[int]messageSpec{ - 94: messageSpec{94, loginResponse, []fixedFieldSpec{okSpec}}, 24: messageSpec{24, patronStatusResponse, []fixedFieldSpec{patronStatusSpec, languageSpec, dateSpec}}, + 94: messageSpec{94, loginResponse, []fixedFieldSpec{okSpec}}, + 98: messageSpec{98, ascStatus, []fixedFieldSpec{onlineStatusSpec, checkinOkSpec, checkoutOkSpec, renewalSpec, stUpdateSpec, offlineOkSpec, timeoutSpec, retriesSpec, dateTimeSyncSpec, protoSpec}}, } variableFieldByCodeMap := map[string]variableFieldSpec{ @@ -119,20 +188,7 @@ func (p *Parser) getFixedFieldValue(msg *message, field string) (string, bool) { return "", false } -func (p *Parser) getFieldValue(msg *message, field string) (string, bool) { - for _, v := range msg.fields { - if v.spec.label == field { - return v.value, true - } - } - return "", false -} - func (p *Parser) parseMessage(msg string) (*message, error) { - /* FIXME */ - /* - http: panic serving 186.26.116.7:1292: runtime error: slice bounds out of range - */ if len(msg) == 0 { return &message{}, errors.New("empty message") } diff --git a/pkg/auth/sip2/spec_test.go b/pkg/auth/sip2/spec_test.go new file mode 100644 index 0000000..92fbdd9 --- /dev/null +++ b/pkg/auth/sip2/spec_test.go @@ -0,0 +1,66 @@ +package sip2 + +import ( + "testing" +) + +const ( + invalidCard = "24YYYY 00020200220 173142AE|AAaaaa|BLN|AFInvalid cardnumber|AOtestlibrary|" + invalidPass = "24 00020200221 185454AE MrUser|AA01000|BLY|CQN|AFGreetings from Koha. -- Invalid password|AOtestlibrary|" + authOK = "24 00020200224 172540AE MrUser|AA01000|BLY|CQY|AFGreetings from Koha. |AOtestlibrary|" + statusOK = "98YYYYNN10000520200221 1853422.00AOtestlibrary|BXYYYYYYYYYYYNYYYY|" +) + +func doParse(txt string) (*message, error) { + p := getParser() + msg, err := p.parseMessage(txt) + return msg, err +} + +func TestInvalidCard(t *testing.T) { + msg, err := doParse(invalidCard) + if err != nil { + t.Fatal("unexpected error", err) + } + validUser, err := isValidUser(msg) + if validUser == true { + t.Fatal("expected invalid user") + } +} + +func TestInvalidPass(t *testing.T) { + msg, err := doParse(invalidPass) + if err != nil { + t.Fatal("unexpected error", err) + } + validPass, err := isValidPassword(msg) + if validPass == true { + t.Fatal("expected invalid pas") + } +} + +func TestAuthOK(t *testing.T) { + msg, err := doParse(authOK) + if err != nil { + t.Fatal("unexpected error", err) + } + validUser, err := isValidUser(msg) + if validUser != true { + t.Fatal("expected valid user") + } + validPass, err := isValidPassword(msg) + if validPass != true { + t.Fatal("expected valid pass") + } +} + +func TestStatusOK(t *testing.T) { + msg, err := doParse(statusOK) + if err != nil { + t.Fatal("unexpected error", err) + } + proto, _ := msg.getFixedFieldValue(protocolVersion) + if proto != "2.00" { + t.Fatal("expected protocol 2.00") + } +} diff --git a/pkg/web/middleware.go b/pkg/web/middleware.go index ed137d6..3ff8938 100644 --- a/pkg/web/middleware.go +++ b/pkg/web/middleware.go @@ -64,14 +64,14 @@ func AuthMiddleware(authenticationFunc func(*creds.Credentials) (bool, error), o if err != nil { metrics.UnavailableLogins.Inc() log.Println("Error while checking credentials: ", err) - http.Error(w, "Auth service unavailable", http.StatusServiceUnavailable) + http.Error(w, "503: Auth service unavailable", http.StatusServiceUnavailable) return } else { metrics.FailedLogins.Inc() if isDebugAuthEnabled(debugFlag) { log.Println("Wrong credentials for user", c.User) } - http.Error(w, "Wrong user and/or password", http.StatusUnauthorized) + http.Error(w, "401: Wrong user and/or password", http.StatusUnauthorized) return } } -- cgit v1.2.3