summaryrefslogtreecommitdiff
path: root/vendor/github.com/smartystreets/assertions/type.go
blob: d2d1dc864b2d0835b8514d26a26f0d928deff073 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
package assertions

import (
	"fmt"
	"reflect"
)

// ShouldHaveSameTypeAs receives exactly two parameters and passes when both
// have the same dynamic type, as reported by reflect.TypeOf. Two untyped nils
// are treated as having the same (nil) type.
func ShouldHaveSameTypeAs(actual interface{}, expected ...interface{}) string {
	if fail := need(1, expected); fail != success {
		return fail
	}

	actualType, expectedType := reflect.TypeOf(actual), reflect.TypeOf(expected[0])
	if actualType == expectedType {
		return success
	}

	// The failure is handed to serializer.serialize along with the expected
	// and actual types (in that order).
	message := fmt.Sprintf(shouldHaveBeenA, actual, expectedType, actualType)
	return serializer.serialize(expectedType, actualType, message)
}

// ShouldNotHaveSameTypeAs receives exactly two parameters and passes only when
// reflect.TypeOf reports different dynamic types for them.
func ShouldNotHaveSameTypeAs(actual interface{}, expected ...interface{}) string {
	if fail := need(1, expected); fail != success {
		return fail
	}

	unwanted := reflect.TypeOf(expected[0])
	// reflect.TypeOf(nil) is nil, so two untyped nils compare equal here and
	// therefore fail the assertion, without needing a separate nil check.
	if reflect.TypeOf(actual) != unwanted {
		return success
	}
	return fmt.Sprintf(shouldNotHaveBeenA, actual, unwanted)
}

// ShouldImplement receives exactly two parameters and ensures
// that the first implements the interface type of the second.
func ShouldImplement(actual interface{}, expectedList ...interface{}) string {
	if fail := need(1, expectedList); fail != success {
		return fail
	}

	expected := expectedList[0]
	if fail := ShouldBeNil(expected); fail != success {
		return shouldCompareWithInterfacePointer
	}

	if fail := ShouldNotBeNil(actual); fail != success {
		return shouldNotBeNilActual
	}

	var actualType reflect.Type
	if reflect.TypeOf(actual).Kind() != reflect.Ptr {
		actualType = reflect.PtrTo(reflect.TypeOf(actual))
	} else {
		actualType = reflect.TypeOf(actual)
	}

	expectedType := reflect.TypeOf(expected)
	if fail := ShouldNotBeNil(expectedType); fail != success {
		return shouldCompareWithInterfacePointer
	}

	expectedInterface := expectedType.Elem()

	if !actualType.Implements(expectedInterface) {
		return fmt.Sprintf(shouldHaveImplemented, expectedInterface, actualType)
	}
	return success
}

// ShouldNotImplement receives exactly two parameters and ensures
// that the first does NOT implement the interface type of the second.
func ShouldNotImplement(actual interface{}, expectedList ...interface{}) string {
	if fail := need(1, expectedList); fail != success {
		return fail
	}

	expected := expectedList[0]
	if fail := ShouldBeNil(expected); fail != success {
		return shouldCompareWithInterfacePointer
	}

	if fail := ShouldNotBeNil(actual); fail != success {
		return shouldNotBeNilActual
	}

	var actualType reflect.Type
	if reflect.TypeOf(actual).Kind() != reflect.Ptr {
		actualType = reflect.PtrTo(reflect.TypeOf(actual))
	} else {
		actualType = reflect.TypeOf(actual)
	}

	expectedType := reflect.TypeOf(expected)
	if fail := ShouldNotBeNil(expectedType); fail != success {
		return shouldCompareWithInterfacePointer
	}

	expectedInterface := expectedType.Elem()

	if actualType.Implements(expectedInterface) {
		return fmt.Sprintf(shouldNotHaveImplemented, actualType, expectedInterface)
	}
	return success
}

// ShouldBeError asserts that the first argument implements the error interface.
// At most one further argument may be supplied. If present, it must be a
// string or an error. The fmt.Sprint rendering of the first argument is then
// compared with ShouldEqual against the fmt.Sprint rendering of that argument.
func ShouldBeError(actual interface{}, expected ...interface{}) string {
	if fail := atMost(1, expected); fail != success {
		return fail
	}
	if !isError(actual) {
		return fmt.Sprintf(shouldBeError, reflect.TypeOf(actual))
	}
	if len(expected) == 0 {
		return success
	}

	want := expected[0]
	switch {
	case isString(want), isError(want):
		return ShouldEqual(fmt.Sprint(actual), fmt.Sprint(want))
	default:
		return fmt.Sprintf(shouldBeErrorInvalidComparisonValue, reflect.TypeOf(want))
	}
}

// isString reports whether value's dynamic type is exactly the predeclared
// string type; named types whose underlying type is string do not match.
func isString(value interface{}) bool {
	switch value.(type) {
	case string:
		return true
	default:
		return false
	}
}

// isError reports whether value's dynamic type implements the error
// interface. An untyped nil yields false.
func isError(value interface{}) bool {
	switch value.(type) {
	case error:
		return true
	default:
		return false
	}
}