From d5fbc0af2488b3908f3e90d6ab7a20a8716cecc2 Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Wed, 7 Jan 2026 12:02:40 -0500 Subject: [PATCH] cgosqlite: add optional hook to verify ColumnBlob bytes are unmodified If a hook is given, then a `runtime.AddCleanup` is registered for the []byte returned from ColumnBlob that verifies whether it's been modified. If so, the hook function is called with the corresponding query. Also, add tests to confirm that this works as expected. Updates tailscale/corp#35671 Signed-off-by: Andrew Dunham --- cgosqlite/cgosqlite.go | 76 ++++++++++- cgosqlite/cgosqlite_test.go | 245 ++++++++++++++++++++++++++++++------ go.mod | 2 +- 3 files changed, 278 insertions(+), 45 deletions(-) diff --git a/cgosqlite/cgosqlite.go b/cgosqlite/cgosqlite.go index 04a1f14..58d44bb 100644 --- a/cgosqlite/cgosqlite.go +++ b/cgosqlite/cgosqlite.go @@ -50,6 +50,8 @@ package cgosqlite */ import "C" import ( + "bytes" + "runtime" "sync" "sync/atomic" "time" @@ -76,6 +78,27 @@ func SetAlwaysCopyBlob(copy bool) { alwaysCopyBlob.Store(copy) } +var columnBlobModifiedHook atomic.Pointer[func(query string)] + +// SetColumnBlobModifiedHook sets a function to be called (in a new goroutine) +// whenever a []byte returned from [Stmt.ColumnBlob] is detected as modified. +// The hook receives the SQL query that created the statement. +// +// Setting a non-nil hook enables verification by attaching a cleanup function +// to each returned slice that compares the final contents against the +// original. Pass nil to disable. +// +// As a necessary side effect, this function causes [Stmt.ColumnBlob] to always +// copy the blob data, to ensure that the comparison in the cleanup function is +// valid, similar to SetAlwaysCopyBlob. 
+func SetColumnBlobModifiedHook(hook func(query string)) { + if hook == nil { + columnBlobModifiedHook.Store(nil) + } else { + columnBlobModifiedHook.Store(&hook) + } +} + func init() { C.sqlite3_initialize() } @@ -92,6 +115,7 @@ type Stmt struct { db *DB stmt *C.sqlite3_stmt start C.struct_timespec + query string // original query, stored for columnBlobModifiedHook // used as scratch space when calling into cgo rowid, changes C.sqlite3_int64 @@ -200,7 +224,7 @@ func (db *DB) Prepare(query string, prepFlags sqliteh.PrepareFlags) (stmt sqlite return nil, "", err } remainingQuery = query[len(query)-int(C.strlen(csqlTail)):] - return &Stmt{db: db, stmt: cstmt}, remainingQuery, nil + return &Stmt{db: db, stmt: cstmt, query: query}, remainingQuery, nil } func (db *DB) DisableFunction(name string, numArgs int) error { @@ -377,6 +401,23 @@ func (stmt *Stmt) ColumnText(col int) string { return C.GoStringN(str, n) } +// blobCheckArg is the argument passed to the cleanup function for verifying +// that the slice returned from ColumnBlob was not modified. +// +// TODO: We use uintptr instead of []byte to avoid keeping the slice alive +// (which would prevent the cleanup from running); is that right? +type blobCheckArg struct { + original []byte // copy of original data + ptr uintptr // pointer to first byte of slice + len int // length of slice + query string // SQL query that produced the blob + hook func(query string) // hook to call if modified +} + +// blobCheckHook, if non-nil, is called after each blob check Cleanup function +// executes. This allows deterministic tests. +var blobCheckHook func() + func (stmt *Stmt) ColumnBlob(col int) []byte { res := C.sqlite3_column_blob(stmt.stmt, C.int(col)) if res == nil { @@ -384,9 +425,38 @@ func (stmt *Stmt) ColumnBlob(col int) []byte { } n := int(C.sqlite3_column_bytes(stmt.stmt, C.int(col))) slice := unsafe.Slice((*byte)(unsafe.Pointer(res)), n) - if alwaysCopyBlob.Load() { - return append([]byte(nil), slice...) 
+ + // In addition to copying if the alwaysCopyBlob flag is set, also copy + // if there is a columnBlobModifiedHook set. This is because a + // runtime.AddCleanup callback executes at some indeterminate time in + // the future, after the point which SQLite might have reused the + // underlying memory. Copying now ensures that the comparison in the + // cleanup function is valid. + hookPtr := columnBlobModifiedHook.Load() + if alwaysCopyBlob.Load() || hookPtr != nil { + slice = append([]byte(nil), slice...) } + + if hookPtr != nil && n > 0 { + arg := blobCheckArg{ + original: bytes.Clone(slice), + ptr: uintptr(unsafe.Pointer(&slice[0])), + len: n, + query: stmt.query, + hook: *hookPtr, + } + runtime.AddCleanup(&slice[0], func(a blobCheckArg) { + current := unsafe.Slice((*byte)(unsafe.Pointer(a.ptr)), a.len) + if !bytes.Equal(current, a.original) { + go a.hook(a.query) + } + + if blobCheckHook != nil { + blobCheckHook() + } + }, arg) + } + return slice } diff --git a/cgosqlite/cgosqlite_test.go b/cgosqlite/cgosqlite_test.go index 9b4f94f..520b949 100644 --- a/cgosqlite/cgosqlite_test.go +++ b/cgosqlite/cgosqlite_test.go @@ -3,7 +3,11 @@ package cgosqlite import ( "bytes" "path/filepath" + "runtime" + "sync" + "sync/atomic" "testing" + "time" "github.com/tailscale/sqlite/sqliteh" ) @@ -28,47 +32,13 @@ func TestColumnBlob(t *testing.T) { } defer db.Close() - mustRun := func(sql string) { - t.Helper() - stmt, _, err := db.Prepare(sql, 0) - if err != nil { - t.Fatalf("Prepare %q: %v", sql, err) - } - if _, err := stmt.Step(nil); err != nil { - t.Fatalf("Step: %v", err) - } - if err := stmt.Finalize(); err != nil { - t.Fatalf("Finalize: %v", err) - } - } - - mustRun("CREATE TABLE t (id INTEGER PRIMARY KEY, data BLOB)") - mustRun(`INSERT INTO t (id, data) VALUES (1, 'HELLOHELLOHELLOHELLOHELLOHELLO99')`) - mustRun(`INSERT INTO t (id, data) VALUES (2, '')`) - mustRun(`INSERT INTO t (id, data) VALUES (3, NULL)`) - - // queryRow runs the given query and returns the *Stmt for 
the first row. - queryRow := func(t *testing.T, sql string) sqliteh.Stmt { - t.Helper() - stmt, _, err := db.Prepare(sql, 0) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { - stmt.Finalize() - }) - row, err := stmt.Step(nil) - if err != nil { - t.Fatal(err) - } - if !row { - t.Fatal("expected a row") - } - return stmt - } + mustRun(t, db, "CREATE TABLE t (id INTEGER PRIMARY KEY, data BLOB)") + mustRun(t, db, `INSERT INTO t (id, data) VALUES (1, 'HELLOHELLOHELLOHELLOHELLOHELLO99')`) + mustRun(t, db, `INSERT INTO t (id, data) VALUES (2, '')`) + mustRun(t, db, `INSERT INTO t (id, data) VALUES (3, NULL)`) t.Run("WithData", func(t *testing.T) { - stmt := queryRow(t, "SELECT data FROM t WHERE id = 1") + stmt := queryRow(t, db, "SELECT data FROM t WHERE id = 1") data := stmt.ColumnBlob(0) const want = "HELLOHELLOHELLOHELLOHELLOHELLO99" @@ -78,7 +48,7 @@ func TestColumnBlob(t *testing.T) { }) t.Run("EmptyBlob", func(t *testing.T) { - stmt := queryRow(t, "SELECT data FROM t WHERE id = 2") + stmt := queryRow(t, db, "SELECT data FROM t WHERE id = 2") data := stmt.ColumnBlob(0) if len(data) != 0 { t.Fatalf("got %d bytes, want 0 bytes", len(data)) @@ -91,7 +61,7 @@ func TestColumnBlob(t *testing.T) { }) t.Run("NullBlob", func(t *testing.T) { - stmt := queryRow(t, "SELECT data FROM t WHERE id = 3") + stmt := queryRow(t, db, "SELECT data FROM t WHERE id = 3") data := stmt.ColumnBlob(0) if data != nil { t.Fatalf("got %q, want nil", data) @@ -100,3 +70,196 @@ func TestColumnBlob(t *testing.T) { }) } } + +func TestColumnBlobModifiedHook(t *testing.T) { + // Disable the "always copy blob" option to test just the hook behavior + SetAlwaysCopyBlob(false) + + // Write to this channel every time a cleanup function executes, so we + // can ensure they've run. 
+ checkRun := make(chan struct{}, 10_000) // high enough to never block + blobCheckHook = func() { + checkRun <- struct{}{} + } + t.Cleanup(func() { + blobCheckHook = nil + }) + + // waitForCleanup waits for one cleanup to run. + waitForCleanup := func() { + timedOut := time.After(10 * time.Second) + for { + runtime.GC() + runtime.Gosched() + + select { + case <-checkRun: + return + case <-t.Context().Done(): + t.Fatal("test context done while waiting for cleanup") + case <-timedOut: + t.Fatal("timeout waiting for cleanup") + case <-time.After(10 * time.Millisecond): + // retry + } + } + } + + // Open a test database + db, err := Open(filepath.Join(t.TempDir(), "test.db"), sqliteh.OpenFlagsDefault, "") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + // Create a table with some blob data + mustRun(t, db, "CREATE TABLE t (id INTEGER PRIMARY KEY, data BLOB)") + + // Use a blob larger than 16 bytes to avoid tiny object optimization which + // can prevent cleanups from running (as mentioned in the documentation + // for [runtime.AddCleanup]). 
+ mustRun(t, db, "INSERT INTO t (id, data) VALUES (1, CAST('HELLOHELLOHELLOHELLOHELLOHELLO99' AS BLOB))") + + const testQuery = "SELECT data FROM t WHERE id = 1" + + t.Run("UnmodifiedSliceDoesNotCallHook", func(t *testing.T) { + var hookCalls atomic.Int64 + SetColumnBlobModifiedHook(func(query string) { + hookCalls.Add(1) + }) + defer SetColumnBlobModifiedHook(nil) + + func() { + stmt := queryRow(t, db, testQuery) + data := stmt.ColumnBlob(0) + if len(data) != 32 { + t.Fatalf("got len %d, want 32", len(data)) + } + + // Don't modify data, just let it go out of scope + runtime.KeepAlive(data) + }() + + waitForCleanup() + if got := hookCalls.Load(); got != 0 { + t.Errorf("hook called %d times, want 0", got) + } + }) + + t.Run("ModifiedSliceCallsHook", func(t *testing.T) { + var ( + hookCalls atomic.Int64 + receivedQuery atomic.Pointer[string] + + calledOnce sync.Once + called = make(chan struct{}) + ) + SetColumnBlobModifiedHook(func(query string) { + hookCalls.Add(1) + receivedQuery.Store(&query) + calledOnce.Do(func() { close(called) }) + }) + defer SetColumnBlobModifiedHook(nil) + + func() { + stmt := queryRow(t, db, testQuery) + data := stmt.ColumnBlob(0) + if len(data) != 32 { + t.Fatalf("got len %d, want 32", len(data)) + } + + // Modify the data to trigger our hook. + data[0] = byte((int(data[0]) + 1) % 256) + + runtime.KeepAlive(data) + }() + + waitForCleanup() + <-called // need to synchronize separately since it's in another goroutine + + if got := hookCalls.Load(); got != 1 { + t.Errorf("hook called %d times, want 1", got) + } + if q := receivedQuery.Load(); q == nil || *q != testQuery { + got := "" + if q != nil { + got = *q + } + t.Errorf("hook received query %q, want %q", got, testQuery) + } + }) + + t.Run("NilHook", func(t *testing.T) { + SetColumnBlobModifiedHook(nil) + + // Ensure we start with an empty channel. 
+ drain: + for { + select { + case <-checkRun: + default: + break drain + } + } + + func() { + stmt := queryRow(t, db, testQuery) + data := stmt.ColumnBlob(0) + if len(data) != 32 { + t.Fatalf("got len %d, want 32", len(data)) + } + + data[0] = 'Y' + + runtime.KeepAlive(data) + }() + + // Spin for a bit to try and trigger any cleanups to be executed. + for i := 0; i < 10; i++ { + runtime.GC() + runtime.Gosched() + time.Sleep(10 * time.Millisecond) + } + + // We expect nothing in the channel, as no hook is set. + select { + case <-checkRun: + t.Fatal("unexpected cleanup hook call") + default: + } + }) +} + +func mustRun(t *testing.T, db sqliteh.DB, sql string) { + t.Helper() + stmt, _, err := db.Prepare(sql, 0) + if err != nil { + t.Fatalf("Prepare %q: %v", sql, err) + } + if _, err := stmt.Step(nil); err != nil { + t.Fatalf("Step: %v", err) + } + if err := stmt.Finalize(); err != nil { + t.Fatalf("Finalize: %v", err) + } +} + +// queryRow runs the given query and returns the *Stmt for the first row. +func queryRow(t *testing.T, db sqliteh.DB, sql string) sqliteh.Stmt { + t.Helper() + stmt, _, err := db.Prepare(sql, 0) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + stmt.Finalize() + }) + row, err := stmt.Step(nil) + if err != nil { + t.Fatal(err) + } + if !row { + t.Fatal("expected a row") + } + return stmt +} diff --git a/go.mod b/go.mod index f3ecd7d..3573c78 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/tailscale/sqlite -go 1.21 +go 1.24