aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/mattn/go-sqlite3/sqlite3_trace.go
blob: a75f52ab6b3615edad99dc0d91a887e93e4818f3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
// Copyright (C) 2016 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build trace

package sqlite3

/*
#ifndef USE_LIBSQLITE3
#include <sqlite3-binding.h>
#else
#include <sqlite3.h>
#endif
#include <stdlib.h>

void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
void doneTrampoline(sqlite3_context*);
int traceCallbackTrampoline(unsigned int traceEventCode, void *ctx, void *p, void *x);
*/
import "C"

import (
	"errors"
	"fmt"
	"reflect"
	"strings"
	"sync"
	"unsafe"
)

// Trace... constants identify the possible events causing callback invocation.
// Values are same as the corresponding SQLite Trace Event Codes.
//
// They are bit flags: OR them together to build TraceConfig.EventMask.
// The TraceInfo fields filled for each event are:
//   - TraceStmt: StmtHandle, StmtOrTrigger, and ExpandedSQL if requested.
//   - TraceProfile: StmtHandle, RunTimeNanosec and DBError.
//   - TraceRow: StmtHandle.
//   - TraceClose: only the common fields. SetTrace always enables this event
//     internally so the driver can drop its per-connection state.
const (
	TraceStmt    = C.SQLITE_TRACE_STMT
	TraceProfile = C.SQLITE_TRACE_PROFILE
	TraceRow     = C.SQLITE_TRACE_ROW
	TraceClose   = C.SQLITE_TRACE_CLOSE
)

// TraceInfo describes one trace event and is passed by value to the
// TraceUserCallback configured with SetTrace.
type TraceInfo struct {
	// Pack together the shorter fields, to keep the struct smaller.
	// On a 64-bit machine there would be padding
	// between EventCode and ConnHandle; having AutoCommit here is "free":
	//   - EventCode is one of the Trace... constants.
	//   - AutoCommit is sqlite3_get_autocommit() sampled when the event fired.
	//   - ConnHandle is the sqlite3* connection pointer, as a uintptr.
	EventCode  uint32
	AutoCommit bool
	ConnHandle uintptr

	// Usually filled, unless EventCode = TraceClose = SQLITE_TRACE_CLOSE:
	// identifier for a prepared statement:
	StmtHandle uintptr

	// Two strings filled when EventCode = TraceStmt = SQLITE_TRACE_STMT:
	// (1) either the unexpanded SQL text of the prepared statement, or
	//     an SQL comment that indicates the invocation of a trigger;
	// (2) expanded SQL, if requested and if (1) is not an SQL comment.
	StmtOrTrigger string
	ExpandedSQL   string // only if requested (TraceConfig.WantExpandedSQL = true)

	// filled when EventCode = TraceProfile = SQLITE_TRACE_PROFILE:
	// estimated number of nanoseconds that the prepared statement took to run:
	RunTimeNanosec int64

	// Error state of the connection. It is sampled for TraceProfile events.
	// It is also filled for TraceStmt events when ExpandedSQL was requested
	// but SQLite could not produce it.
	DBError Error
}

// TraceUserCallback gives the signature for a trace function
// provided by the user (Go application programmer).
// SQLite 3.14 documentation (as of September 2, 2016)
// for SQL Trace Hook = sqlite3_trace_v2():
// The integer return value from the callback is currently ignored,
// though this may change in future releases. Callback implementations
// should return zero to ensure future compatibility.
//
// The driver hands the returned value back to SQLite unchanged, as the
// result of its trace_v2 callback.
type TraceUserCallback func(TraceInfo) int

// TraceConfig selects which trace events are delivered and to whom.
// Pass a pointer to one to SQLiteConn.SetTrace; SetTrace keeps its own copy.
type TraceConfig struct {
	// Callback is invoked for every event whose bit is set in EventMask.
	// It may be nil, in which case events are observed but nothing is called.
	// EventMask is a bitwise OR of the Trace... constants.
	// WantExpandedSQL requests TraceInfo.ExpandedSQL for TraceStmt events.
	// SetTrace clears it when TraceStmt is not in EventMask.
	Callback        TraceUserCallback
	EventMask       C.uint
	WantExpandedSQL bool
}

// fillDBError copies the connection's current error state (primary code,
// extended code and message text) into dbErr. Compare SQLiteConn.lastError()
// in sqlite3.go, which builds an Error from the same three SQLite calls.
func fillDBError(dbErr *Error, db *C.sqlite3) {
	primary := C.sqlite3_errcode(db)
	extended := C.sqlite3_extended_errcode(db)
	// Read the message last: SQLite may overwrite the buffer on later API
	// calls, so convert it before calling into SQLite again.
	message := C.sqlite3_errmsg(db)

	dbErr.Code = ErrNo(primary)
	dbErr.ExtendedCode = ErrNoExtended(extended)
	dbErr.err = C.GoString(message)
}

// fillExpandedSQL stores the text returned by sqlite3_expanded_sql for the
// prepared statement pStmt into info.ExpandedSQL. If SQLite returns NULL,
// info.DBError is filled from db instead and ExpandedSQL is left empty.
//
// It panics if pStmt is nil; callers pass the P argument of a
// SQLITE_TRACE_STMT event, which should be the statement pointer.
func fillExpandedSQL(info *TraceInfo, db *C.sqlite3, pStmt unsafe.Pointer) {
	if pStmt == nil {
		panic("No SQLite statement pointer in P arg of trace_v2 callback")
	}

	expSQLiteCStr := C.sqlite3_expanded_sql((*C.sqlite3_stmt)(pStmt))
	if expSQLiteCStr == nil {
		fillDBError(&info.DBError, db)
		return
	}
	// Unlike sqlite3_sql, sqlite3_expanded_sql returns a string obtained from
	// sqlite3_malloc that the caller owns. Release it once it has been
	// copied into Go memory, otherwise every traced statement leaks it.
	defer C.sqlite3_free(unsafe.Pointer(expSQLiteCStr))
	info.ExpandedSQL = C.GoString(expSQLiteCStr)
}

// traceCallbackTrampoline is the C-callable callback that setSQLiteTrace
// registers with sqlite3_trace_v2. Its C prototype is declared in this
// file's cgo preamble.
//
// ctx is the registration context. setSQLiteTrace passes the sqlite3*
// connection handle there, so it also serves as the key into traceMap.
// The function:
//   - fetches that connection's TraceConfig, removing it from traceMap on
//     TraceClose;
//   - fills a TraceInfo according to the event code;
//   - invokes the user callback only if the event's bit is set in EventMask.
//
// It returns the callback's result, or 0 when the event is masked out or no
// callback is configured.
//
// It panics on a nil ctx, on a missing traceMap entry, and on malformed event
// data: a nil X for TraceProfile, or a mismatched handle for TraceClose.
// NOTE(review): these panics are raised inside a callback invoked from C.
// Confirm that terminating the process is the intended failure mode.
//
//export traceCallbackTrampoline
func traceCallbackTrampoline(
	traceEventCode C.uint,
	// Parameter named 'C' in SQLite docs = Context given at registration:
	ctx unsafe.Pointer,
	// Parameter named 'P' in SQLite docs (Primary event data?):
	p unsafe.Pointer,
	// Parameter named 'X' in SQLite docs (eXtra event data?):
	xValue unsafe.Pointer) C.int {

	if ctx == nil {
		panic(fmt.Sprintf("No context (ev 0x%x)", traceEventCode))
	}

	// The context is the connection handle given to sqlite3_trace_v2 by
	// setSQLiteTrace; use it both as a C pointer and as the map key.
	contextDB := (*C.sqlite3)(ctx)
	connHandle := uintptr(ctx)

	var traceConf TraceConfig
	var found bool
	if traceEventCode == TraceClose {
		// clean up traceMap: 'pop' means get and delete
		traceConf, found = popTraceMapping(connHandle)
	} else {
		traceConf, found = lookupTraceMapping(connHandle)
	}

	if !found {
		panic(fmt.Sprintf("Mapping not found for handle 0x%x (ev 0x%x)",
			connHandle, traceEventCode))
	}

	var info TraceInfo

	// Fields common to every event.
	info.EventCode = uint32(traceEventCode)
	info.AutoCommit = (int(C.sqlite3_get_autocommit(contextDB)) != 0)
	info.ConnHandle = connHandle

	switch traceEventCode {
	case TraceStmt:
		info.StmtHandle = uintptr(p)

		// X is a C string: the statement's SQL text, or a "--" comment
		// when the event is caused by a trigger.
		var xStr string
		if xValue != nil {
			xStr = C.GoString((*C.char)(xValue))
		}
		info.StmtOrTrigger = xStr
		if !strings.HasPrefix(xStr, "--") {
			// Not SQL comment, therefore the current event
			// is not related to a trigger.
			// The user might want to receive the expanded SQL;
			// let's check:
			if traceConf.WantExpandedSQL {
				fillExpandedSQL(&info, contextDB, p)
			}
		}

	case TraceProfile:
		info.StmtHandle = uintptr(p)

		if xValue == nil {
			panic("NULL pointer in X arg of trace_v2 callback for SQLITE_TRACE_PROFILE event")
		}

		// X is read as a pointer to a 64-bit nanosecond run time.
		info.RunTimeNanosec = *(*int64)(xValue)

		// sample the error //TODO: is it safe? is it useful?
		fillDBError(&info.DBError, contextDB)

	case TraceRow:
		info.StmtHandle = uintptr(p)

	case TraceClose:
		// P must be the same connection handle that was registered as context.
		handle := uintptr(p)
		if handle != info.ConnHandle {
			panic(fmt.Sprintf("Different conn handle 0x%x (expected 0x%x) in SQLITE_TRACE_CLOSE event.",
				handle, info.ConnHandle))
		}

	default:
		// Pass unsupported events to the user callback (if configured);
		// let the user callback decide whether to panic or ignore them.
	}

	// Do not execute user callback when the event was not requested by user!
	// Remember that the Close event is always selected when
	// registering this callback trampoline with SQLite --- for cleanup.
	// In the future there may be more events forced to "selected" in SQLite
	// for the driver's needs.
	if traceConf.EventMask&traceEventCode == 0 {
		return 0
	}

	r := 0
	if traceConf.Callback != nil {
		r = traceConf.Callback(info)
	}
	return C.int(r)
}

// traceMapEntry is the value kept in traceMap for a connection that has a
// trace configuration installed via SetTrace.
type traceMapEntry struct {
	config TraceConfig
}

// traceMap associates a connection handle (the sqlite3* pointer as a uintptr)
// with its trace configuration. Every access holds traceMapLock.
var (
	traceMapLock sync.Mutex
	traceMap     = make(map[uintptr]traceMapEntry)
)

// addTraceMapping records traceConf as the trace configuration of the
// connection identified by connHandle.
//
// It panics if connHandle already has an entry. SetTrace removes any previous
// entry before calling this, so a duplicate should not occur in normal use.
func addTraceMapping(connHandle uintptr, traceConf TraceConfig) {
	traceMapLock.Lock()
	defer traceMapLock.Unlock()

	if oldEntry, found := traceMap[connHandle]; found {
		panic(fmt.Sprintf("Adding trace config %v: handle 0x%x already registered (%v).",
			traceConf, connHandle, oldEntry.config))
	}
	// Deliberately silent: a library must not write debug output to stdout.
	traceMap[connHandle] = traceMapEntry{config: traceConf}
}

// lookupTraceMapping reports the trace configuration registered for
// connHandle, leaving traceMap unchanged. The bool is false if none exists.
func lookupTraceMapping(connHandle uintptr) (TraceConfig, bool) {
	traceMapLock.Lock()
	entry, ok := traceMap[connHandle]
	traceMapLock.Unlock()

	return entry.config, ok
}

// popTraceMapping returns the trace configuration registered for connHandle
// and removes it from traceMap ('pop' = get and delete). The bool reports
// whether an entry was present. Missing entries are not an error here.
func popTraceMapping(connHandle uintptr) (TraceConfig, bool) {
	traceMapLock.Lock()
	defer traceMapLock.Unlock()

	entryCopy, found := traceMap[connHandle]
	if found {
		// Deliberately silent: a library must not write debug output to stdout.
		delete(traceMap, connHandle)
	}
	return entryCopy.config, found
}

// RegisterAggregator makes a Go type available as a SQLite aggregation function.
//
// Because aggregation is incremental, it's implemented in Go with a
// type that has 2 methods: func Step(values) accumulates one row of
// data into the accumulator, and func Done() ret finalizes and
// returns the aggregate value. "values" and "ret" may be any type
// supported by RegisterFunc.
//
// RegisterAggregator takes as implementation a constructor function
// that constructs an instance of the aggregator type each time an
// aggregation begins. The constructor must return a pointer to a
// type, or an interface that implements Step() and Done().
//
// The constructor function and the Step/Done methods may optionally
// return an error in addition to their other return values.
//
// Step may be variadic. The SQL function is then registered as accepting
// any number of arguments, and the call helper checks that at least the
// fixed ones are present.
//
// If pure is true the function is registered as SQLITE_DETERMINISTIC.
//
// See _example/go_custom_funcs for a detailed example.
func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error {
	errorType := reflect.TypeOf((*error)(nil)).Elem()

	var ai aggInfo
	ai.constructor = reflect.ValueOf(impl)
	// Check Kind before Type: for a nil impl the Value is the zero Value,
	// whose Type() panics while Kind() simply reports reflect.Invalid.
	if ai.constructor.Kind() != reflect.Func {
		return errors.New("non-function passed to RegisterAggregator")
	}
	t := ai.constructor.Type()
	if t.NumOut() != 1 && t.NumOut() != 2 {
		return errors.New("SQLite aggregator constructors must return 1 or 2 values")
	}
	if t.NumOut() == 2 && !t.Out(1).Implements(errorType) {
		return errors.New("Second return value of SQLite function must be error")
	}
	if t.NumIn() != 0 {
		return errors.New("SQLite aggregator constructors must not have arguments")
	}

	agg := t.Out(0)
	switch agg.Kind() {
	case reflect.Ptr, reflect.Interface:
	default:
		return errors.New("SQlite aggregator constructor must return a pointer object")
	}

	// MethodByName on a pointer type yields a function type whose first
	// input is the receiver; on an interface type there is no receiver.
	stepFn, found := agg.MethodByName("Step")
	if !found {
		return errors.New("SQlite aggregator doesn't have a Step() function")
	}
	step := stepFn.Type
	if step.NumOut() != 0 && step.NumOut() != 1 {
		return errors.New("SQlite aggregator Step() function must return 0 or 1 values")
	}
	if step.NumOut() == 1 && !step.Out(0).Implements(errorType) {
		return errors.New("type of SQlite aggregator Step() return value must be error")
	}

	stepNArgs := step.NumIn()
	start := 0
	if agg.Kind() == reflect.Ptr {
		// Skip over the method receiver
		stepNArgs--
		start++
	}
	if step.IsVariadic() {
		stepNArgs--
	}
	for i := start; i < start+stepNArgs; i++ {
		conv, err := callbackArg(step.In(i))
		if err != nil {
			return err
		}
		ai.stepArgConverters = append(ai.stepArgConverters, conv)
	}
	if step.IsVariadic() {
		// The variadic slice is the last input of Step itself. Its element
		// type decides how each extra SQL argument is converted. (The
		// constructor type t has no inputs, so it must not be indexed here.)
		conv, err := callbackArg(step.In(start + stepNArgs).Elem())
		if err != nil {
			return err
		}
		ai.stepVariadicConverter = conv
		// Pass -1 to sqlite so that it allows any number of
		// arguments. The call helper verifies that the minimum number
		// of arguments is present for variadic functions.
		stepNArgs = -1
	}

	doneFn, found := agg.MethodByName("Done")
	if !found {
		return errors.New("SQlite aggregator doesn't have a Done() function")
	}
	done := doneFn.Type
	doneNArgs := done.NumIn()
	if agg.Kind() == reflect.Ptr {
		// Skip over the method receiver
		doneNArgs--
	}
	if doneNArgs != 0 {
		return errors.New("SQlite aggregator Done() function must have no arguments")
	}
	if done.NumOut() != 1 && done.NumOut() != 2 {
		return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
	}
	if done.NumOut() == 2 && !done.Out(1).Implements(errorType) {
		return errors.New("second return value of SQLite aggregator Done() function must be error")
	}

	conv, err := callbackRet(done.Out(0))
	if err != nil {
		return err
	}
	ai.doneRetConverter = conv
	ai.active = make(map[int64]reflect.Value)
	ai.next = 1

	// ai must outlast the database connection, or we'll have dangling pointers.
	c.aggregators = append(c.aggregators, &ai)

	cname := C.CString(name)
	defer C.free(unsafe.Pointer(cname))
	opts := C.SQLITE_UTF8
	if pure {
		opts |= C.SQLITE_DETERMINISTIC
	}
	rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
	if rv != C.SQLITE_OK {
		return c.lastError()
	}
	return nil
}

// SetTrace installs or removes the trace callback for the given database connection.
// It's not named 'RegisterTrace' because only one callback can be kept and called.
// Each call replaces whatever an earlier call on the same connection installed:
// the callback, the event mask and every other setting.
// A nil requested disables tracing for the connection.
// The configuration is copied, so later changes to *requested have no effect.
func (c *SQLiteConn) SetTrace(requested *TraceConfig) error {
	handle := uintptr(unsafe.Pointer(c.db))

	// Forget any earlier configuration. Whether one existed is irrelevant.
	_, _ = popTraceMapping(handle)

	if requested == nil {
		// Nothing is registered anymore, so TraceClose need not stay
		// enabled for cleanup: switch every event off.
		return c.setSQLiteTrace(0)
	}

	conf := *requested

	// Expanded SQL is only produced for statement events. Drop the request
	// otherwise, in case the caller passed a nonsensical combination.
	conf.WantExpandedSQL = conf.WantExpandedSQL && conf.EventMask&TraceStmt != 0

	addTraceMapping(handle, conf)

	// The trampoline removes the traceMap entry when it sees TraceClose, so
	// that event is always enabled in SQLite, requested or not.
	return c.setSQLiteTrace(uint(conf.EventMask | TraceClose))
}

// setSQLiteTrace registers traceCallbackTrampoline with sqlite3_trace_v2 for
// the events in sqliteEventMask; a zero mask turns tracing off. The
// connection handle c.db is passed as the callback context, which the
// trampoline uses as its traceMap key.
func (c *SQLiteConn) setSQLiteTrace(sqliteEventMask uint) error {
	trampoline := (*[0]byte)(unsafe.Pointer(C.traceCallbackTrampoline))
	context := unsafe.Pointer(c.db)

	if rv := C.sqlite3_trace_v2(c.db, C.uint(sqliteEventMask), trampoline, context); rv != C.SQLITE_OK {
		return c.lastError()
	}
	return nil
}