summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/denisenkom/go-mssqldb/parser.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/parser.go')
-rw-r--r--vendor/github.com/denisenkom/go-mssqldb/parser.go40
1 files changed, 35 insertions, 5 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/parser.go b/vendor/github.com/denisenkom/go-mssqldb/parser.go
index 9e37c16a65..8021ca603c 100644
--- a/vendor/github.com/denisenkom/go-mssqldb/parser.go
+++ b/vendor/github.com/denisenkom/go-mssqldb/parser.go
@@ -11,6 +11,9 @@ type parser struct {
w bytes.Buffer
paramCount int
paramMax int
+
+ // using map as a set
+ namedParams map[string]bool
}
func (p *parser) next() (rune, bool) {
@@ -39,13 +42,14 @@ type stateFunc func(*parser) stateFunc
func parseParams(query string) (string, int) {
p := &parser{
- r: bytes.NewReader([]byte(query)),
+ r: bytes.NewReader([]byte(query)),
+ namedParams: map[string]bool{},
}
state := parseNormal
for state != nil {
state = state(p)
}
- return p.w.String(), p.paramMax
+ return p.w.String(), p.paramMax + len(p.namedParams)
}
func parseNormal(p *parser) stateFunc {
@@ -55,7 +59,7 @@ func parseNormal(p *parser) stateFunc {
return nil
}
if ch == '?' {
- return parseParameter
+ return parseOrdinalParameter
} else if ch == '$' || ch == ':' {
ch2, ok := p.next()
if !ok {
@@ -64,7 +68,9 @@ func parseNormal(p *parser) stateFunc {
}
p.unread()
if ch2 >= '0' && ch2 <= '9' {
- return parseParameter
+ return parseOrdinalParameter
+ } else if 'a' <= ch2 && ch2 <= 'z' || 'A' <= ch2 && ch2 <= 'Z' {
+ return parseNamedParameter
}
}
p.write(ch)
@@ -83,7 +89,7 @@ func parseNormal(p *parser) stateFunc {
}
}
-func parseParameter(p *parser) stateFunc {
+func parseOrdinalParameter(p *parser) stateFunc {
var paramN int
var ok bool
for {
@@ -113,6 +119,30 @@ func parseParameter(p *parser) stateFunc {
return parseNormal
}
+func parseNamedParameter(p *parser) stateFunc {
+ var paramName string
+ var ok bool
+ for {
+ var ch rune
+ ch, ok = p.next()
+ if ok && (ch >= '0' && ch <= '9' || 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z') {
+ paramName = paramName + string(ch)
+ } else {
+ break
+ }
+ }
+ if ok {
+ p.unread()
+ }
+ p.namedParams[paramName] = true
+ p.w.WriteString("@")
+ p.w.WriteString(paramName)
+ if !ok {
+ return nil
+ }
+ return parseNormal
+}
+
func parseQuote(p *parser) stateFunc {
for {
ch, ok := p.next()