aboutsummaryrefslogtreecommitdiffstats
path: root/modules/middleware/binding/binding.go
blob: bf8ab52b2b096c7813cc7549a4bf18902bf62b68 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
// Copyright 2013 The Martini Contrib Authors. All rights reserved.
// Copyright 2014 The Gogs Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.

package binding

import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"reflect"
	"regexp"
	"strconv"
	"strings"
	"unicode/utf8"

	"github.com/go-martini/martini"
)

/*
	To the land of Middle-ware Earth:

		One func to rule them all,
		One func to find them,
		One func to bring them all,
		And in this package BIND them.
*/

// Bind accepts a copy of an empty struct and populates it with
// values from the request (if deserialization is successful). It
// wraps up the functionality of the Form and Json middleware
// according to the Content-Type of the request, and it guesses
// if no Content-Type is specified. Bind invokes the ErrorHandler
// middleware to bail out if errors occurred. If you want to perform
// your own error handling, use Form or Json middleware directly.
// An interface pointer can be added as a second argument in order
// to map the struct to a specific interface.
func Bind(obj interface{}, ifacePtr ...interface{}) martini.Handler {
	return func(context martini.Context, req *http.Request) {
		contentType := req.Header.Get("Content-Type")

		if strings.Contains(contentType, "form-urlencoded") {
			context.Invoke(Form(obj, ifacePtr...))
		} else if strings.Contains(contentType, "multipart/form-data") {
			context.Invoke(MultipartForm(obj, ifacePtr...))
		} else if strings.Contains(contentType, "json") {
			context.Invoke(Json(obj, ifacePtr...))
		} else {
			context.Invoke(Json(obj, ifacePtr...))
			if getErrors(context).Count() > 0 {
				context.Invoke(Form(obj, ifacePtr...))
			}
		}

		context.Invoke(ErrorHandler)
	}
}

// BindIgnErr will do the exactly same thing as Bind but without any
// error handling, which user has freedom to deal with them.
// This allows user take advantages of validation.
func BindIgnErr(obj interface{}, ifacePtr ...interface{}) martini.Handler {
	return func(context martini.Context, req *http.Request) {
		contentType := req.Header.Get("Content-Type")

		if strings.Contains(contentType, "form-urlencoded") {
			context.Invoke(Form(obj, ifacePtr...))
		} else if strings.Contains(contentType, "multipart/form-data") {
			context.Invoke(MultipartForm(obj, ifacePtr...))
		} else if strings.Contains(contentType, "json") {
			context.Invoke(Json(obj, ifacePtr...))
		} else {
			context.Invoke(Json(obj, ifacePtr...))
			if getErrors(context).Count() > 0 {
				context.Invoke(Form(obj, ifacePtr...))
			}
		}
	}
}

// Form is middleware to deserialize form-urlencoded data from the request.
// It gets data from the form-urlencoded body, if present, or from the
// query string. It uses the http.Request.ParseForm() method
// to perform deserialization, then reflection is used to map each field
// into the struct with the proper type. Structs with primitive slice types
// (bool, float, int, string) can support deserialization of repeated form
// keys, for example: key=val1&key=val2&key=val3
// An interface pointer can be added as a second argument in order
// to map the struct to a specific interface.
func Form(formStruct interface{}, ifacePtr ...interface{}) martini.Handler {
	return func(context martini.Context, req *http.Request) {
		ensureNotPointer(formStruct)
		formStruct := reflect.New(reflect.TypeOf(formStruct))
		errors := newErrors()
		parseErr := req.ParseForm()

		// Format validation of the request body or the URL would add considerable overhead,
		// and ParseForm does not complain when URL encoding is off.
		// Because an empty request body or url can also mean absence of all needed values,
		// it is not in all cases a bad request, so let's return 422.
		if parseErr != nil {
			errors.Overall[BindingDeserializationError] = parseErr.Error()
		}

		mapForm(formStruct, req.Form, errors)

		validateAndMap(formStruct, context, errors, ifacePtr...)
	}
}

func MultipartForm(formStruct interface{}, ifacePtr ...interface{}) martini.Handler {
	return func(context martini.Context, req *http.Request) {
		ensureNotPointer(formStruct)
		formStruct := reflect.New(reflect.TypeOf(formStruct))
		errors := newErrors()

		// Workaround for multipart forms returning nil instead of an error
		// when content is not multipart
		// https://code.google.com/p/go/issues/detail?id=6334
		multipartReader, err := req.MultipartReader()
		if err != nil {
			errors.Overall[BindingDeserializationError] = err.Error()
		} else {
			form, parseErr := multipartReader.ReadForm(MaxMemory)

			if parseErr != nil {
				errors.Overall[BindingDeserializationError] = parseErr.Error()
			}

			req.MultipartForm = form
		}

		mapForm(formStruct, req.MultipartForm.Value, errors)

		validateAndMap(formStruct, context, errors, ifacePtr...)
	}
}

// Json is middleware to deserialize a JSON payload from the request
// into the struct that is passed in. The resulting struct is then
// validated, but no error handling is actually performed here.
// An interface pointer can be added as a second argument in order
// to map the struct to a specific interface.
func Json(jsonStruct interface{}, ifacePtr ...interface{}) martini.Handler {
	return func(context martini.Context, req *http.Request) {
		ensureNotPointer(jsonStruct)
		jsonStruct := reflect.New(reflect.TypeOf(jsonStruct))
		errors := newErrors()

		if req.Body != nil {
			defer req.Body.Close()
		}

		if err := json.NewDecoder(req.Body).Decode(jsonStruct.Interface()); err != nil && err != io.EOF {
			errors.Overall[BindingDeserializationError] = err.Error()
		}

		validateAndMap(jsonStruct, context, errors, ifacePtr...)
	}
}

// Validate is middleware to enforce required fields. If the struct
// passed in is a Validator, then the user-defined Validate method
// is executed, and its errors are mapped to the context. This middleware
// performs no error handling: it merely detects them and maps them.
func Validate(obj interface{}) martini.Handler {
	return func(context martini.Context, req *http.Request) {
		errors := newErrors()
		validateStruct(errors, obj)

		if validator, ok := obj.(Validator); ok {
			validator.Validate(errors, req, context)
		}
		context.Map(*errors)
	}
}

var (
	alphaDashPattern    = regexp.MustCompile("[^\\d\\w-_]")
	alphaDashDotPattern = regexp.MustCompile("[^\\d\\w-_\\.]")
	emailPattern        = regexp.MustCompile("[\\w!#$%&'*+/=?^_`{|}~-]+(?:\\.[\\w!#$%&'*+/=?^_`{|}~-]+)*@(?:[\\w](?:[\\w-]*[\\w])?\\.)+[a-zA-Z0-9](?:[\\w-]*[\\w])?")
	urlPattern          = regexp.MustCompile(`(http|https):\/\/[\w\-_]+(\.[\w\-_]+)+([\w\-\.,@?^=%&:/~\+#]*[\w\-\@?^=%&/~\+#])?`)
)

func validateStruct(errors *Errors, obj interface{}) {
	typ := reflect.TypeOf(obj)
	val := reflect.ValueOf(obj)

	if typ.Kind() == reflect.Ptr {
		typ = typ.Elem()
		val = val.Elem()
	}

	for i := 0; i < typ.NumField(); i++ {
		field := typ.Field(i)

		// Allow ignored fields in the struct
		if field.Tag.Get("form") == "-" {
			continue
		}

		fieldValue := val.Field(i).Interface()
		if field.Type.Kind() == reflect.Struct {
			validateStruct(errors, fieldValue)
			continue
		}

		zero := reflect.Zero(field.Type).Interface()

		// Match rules.
		for _, rule := range strings.Split(field.Tag.Get("binding"), ";") {
			if len(rule) == 0 {
				continue
			}

			switch {
			case rule == "Required":
				if reflect.DeepEqual(zero, fieldValue) {
					errors.Fields[field.Name] = BindingRequireError
					break
				}
			case rule == "AlphaDash":
				if alphaDashPattern.MatchString(fmt.Sprintf("%v", fieldValue)) {
					errors.Fields[field.Name] = BindingAlphaDashError
					break
				}
			case rule == "AlphaDashDot":
				if alphaDashDotPattern.MatchString(fmt.Sprintf("%v", fieldValue)) {
					errors.Fields[field.Name] = BindingAlphaDashDotError
					break
				}
			case strings.HasPrefix(rule, "MinSize("):
				min, err := strconv.Atoi(rule[8 : len(rule)-1])
				if err != nil {
					errors.Overall["MinSize"] = err.Error()
					break
				}
				if str, ok := fieldValue.(string); ok && utf8.RuneCountInString(str) < min {
					errors.Fields[field.Name] = BindingMinSizeError
					break
				}
				v := reflect.ValueOf(fieldValue)
				if v.Kind() == reflect.Slice && v.Len() < min {
					errors.Fields[field.Name] = BindingMinSizeError
					break
				}
			case strings.HasPrefix(rule, "MaxSize("):
				max, err := strconv.Atoi(rule[8 : len(rule)-1])
				if err != nil {
					errors.Overall["MaxSize"] = err.Error()
					break
				}
				if str, ok := fieldValue.(string); ok && utf8.RuneCountInString(str) > max {
					errors.Fields[field.Name] = BindingMaxSizeError
					break
				}
				v := reflect.ValueOf(fieldValue)
				if v.Kind() == reflect.Slice && v.Len() > max {
					errors.Fields[field.Name] = BindingMinSizeError
					break
				}
			case rule == "Email":
				if !emailPattern.MatchString(fmt.Sprintf("%v", fieldValue)) {
					errors.Fields[field.Name] = BindingEmailError
					break
				}
			case rule == "Url":
				str := fmt.Sprintf("%v", fieldValue)
				if len(str) == 0 {
					continue
				} else if !urlPattern.MatchString(str) {
					errors.Fields[field.Name] = BindingUrlError
					break
				}
			}
		}
	}
}

func mapForm(formStruct reflect.Value, form map[string][]string, errors *Errors) {
	typ := formStruct.Elem().Type()

	for i := 0; i < typ.NumField(); i++ {
		typeField := typ.Field(i)
		if inputFieldName := typeField.Tag.Get("form"); inputFieldName != "" {
			structField := formStruct.Elem().Field(i)
			if !structField.CanSet() {
				continue
			}

			inputValue, exists := form[inputFieldName]

			if !exists {
				continue
			}

			numElems := len(inputValue)
			if structField.Kind() == reflect.Slice && numElems > 0 {
				sliceOf := structField.Type().Elem().Kind()
				slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
				for i := 0; i < numElems; i++ {
					setWithProperType(sliceOf, inputValue[i], slice.Index(i), inputFieldName, errors)
				}
				formStruct.Elem().Field(i).Set(slice)
			} else {
				setWithProperType(typeField.Type.Kind(), inputValue[0], structField, inputFieldName, errors)
			}
		}
	}
}

// ErrorHandler simply counts the number of errors in the
// context and, if more than 0, writes a 400 Bad Request
// response and a JSON payload describing the errors with
// the "Content-Type" set to "application/json".
// Middleware remaining on the stack will not even see the request
// if, by this point, there are any errors.
// This is a "default" handler, of sorts, and you are
// welcome to use your own instead. The Bind middleware
// invokes this automatically for convenience.
func ErrorHandler(errs Errors, resp http.ResponseWriter) {
	if errs.Count() > 0 {
		resp.Header().Set("Content-Type", "application/json; charset=utf-8")
		if _, ok := errs.Overall[BindingDeserializationError]; ok {
			resp.WriteHeader(http.StatusBadRequest)
		} else {
			resp.WriteHeader(422)
		}
		errOutput, _ := json.Marshal(errs)
		resp.Write(errOutput)
		return
	}
}

// This sets the value in a struct of an indeterminate type to the
// matching value from the request (via Form middleware) in the
// same type, so that not all deserialized values have to be strings.
// Supported types are string, int, float, and bool.
func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value, nameInTag string, errors *Errors) {
	switch valueKind {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		if val == "" {
			val = "0"
		}
		intVal, err := strconv.ParseInt(val, 10, 64)
		if err != nil {
			errors.Fields[nameInTag] = BindingIntegerTypeError
		} else {
			structField.SetInt(intVal)
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		if val == "" {
			val = "0"
		}
		uintVal, err := strconv.ParseUint(val, 10, 64)
		if err != nil {
			errors.Fields[nameInTag] = BindingIntegerTypeError
		} else {
			structField.SetUint(uintVal)
		}
	case reflect.Bool:
		structField.SetBool(val == "on")
	case reflect.Float32:
		if val == "" {
			val = "0.0"
		}
		floatVal, err := strconv.ParseFloat(val, 32)
		if err != nil {
			errors.Fields[nameInTag] = BindingFloatTypeError
		} else {
			structField.SetFloat(floatVal)
		}
	case reflect.Float64:
		if val == "" {
			val = "0.0"
		}
		floatVal, err := strconv.ParseFloat(val, 64)
		if err != nil {
			errors.Fields[nameInTag] = BindingFloatTypeError
		} else {
			structField.SetFloat(floatVal)
		}
	case reflect.String:
		structField.SetString(val)
	}
}

// Don't pass in pointers to bind to. Can lead to bugs. See:
// https://github.com/codegangsta/martini-contrib/issues/40
// https://github.com/codegangsta/martini-contrib/pull/34#issuecomment-29683659
func ensureNotPointer(obj interface{}) {
	if reflect.TypeOf(obj).Kind() == reflect.Ptr {
		panic("Pointers are not accepted as binding models")
	}
}

// Performs validation and combines errors from validation
// with errors from deserialization, then maps both the
// resulting struct and the errors to the context.
func validateAndMap(obj reflect.Value, context martini.Context, errors *Errors, ifacePtr ...interface{}) {
	context.Invoke(Validate(obj.Interface()))
	errors.Combine(getErrors(context))
	context.Map(*errors)
	context.Map(obj.Elem().Interface())
	if len(ifacePtr) > 0 {
		context.MapTo(obj.Elem().Interface(), ifacePtr[0])
	}
}

func newErrors() *Errors {
	return &Errors{make(map[string]string), make(map[string]string)}
}

func getErrors(context martini.Context) Errors {
	return context.Get(reflect.TypeOf(Errors{})).Interface().(Errors)
}

type (
	// Implement the Validator interface to define your own input
	// validation before the request even gets to your application.
	// The Validate method will be executed during the validation phase.
	Validator interface {
		Validate(*Errors, *http.Request, martini.Context)
	}
)

var (
	// Maximum amount of memory to use when parsing a multipart form.
	// Set this to whatever value you prefer; default is 10 MB.
	MaxMemory = int64(1024 * 1024 * 10)
)

// Errors represents the contract of the response body when the
// binding step fails before getting to the application.
type Errors struct {
	Overall map[string]string `json:"overall"`
	Fields  map[string]string `json:"fields"`
}

// Total errors is the sum of errors with the request overall
// and errors on individual fields.
func (err Errors) Count() int {
	return len(err.Overall) + len(err.Fields)
}

func (this *Errors) Combine(other Errors) {
	for key, val := range other.Fields {
		if _, exists := this.Fields[key]; !exists {
			this.Fields[key] = val
		}
	}
	for key, val := range other.Overall {
		if _, exists := this.Overall[key]; !exists {
			this.Overall[key] = val
		}
	}
}

const (
	BindingRequireError         string = "Required"
	BindingAlphaDashError       string = "AlphaDash"
	BindingAlphaDashDotError    string = "AlphaDashDot"
	BindingMinSizeError         string = "MinSize"
	BindingMaxSizeError         string = "MaxSize"
	BindingEmailError           string = "Email"
	BindingUrlError             string = "Url"
	BindingDeserializationError string = "DeserializationError"
	BindingIntegerTypeError     string = "IntegerTypeError"
	BindingBooleanTypeError     string = "BooleanTypeError"
	BindingFloatTypeError       string = "FloatTypeError"
)