diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 926f7a6..64b63c3 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -1,7 +1,7 @@ # This workflow will build a golang project # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go -name: go-querysql-test +name: sqlcode on: pull_request: @@ -11,19 +11,16 @@ jobs: build: runs-on: ubuntu-latest - env: - SQLSERVER_DSN: "sqlserver://127.0.0.1:1433?database=master&user id=sa&password=VippsPw1" + strategy: + matrix: + driver: ['mssql','pgsql'] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Set up Go uses: actions/setup-go@v5 with: - go-version: '1.23' - - - name: Start db - run: docker compose -f docker-compose.test.yml up -d + go-version: '1.25' - name: Test - # Skip the example folder because it has examples of what-not-to-do - run: go test -v $(go list ./... | grep -v './example') + run: docker compose -f docker-compose.${{ matrix.driver }}.yml run test \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..fbcbdeb --- /dev/null +++ b/Makefile @@ -0,0 +1,8 @@ +test: test_mssql test_pgsql + + +test_mssql: + docker compose --progress plain -f docker-compose.mssql.yml run test + +test_pgsql: + docker compose --progress plain -f docker-compose.pgsql.yml run test \ No newline at end of file diff --git a/cli/cmd/build.go b/cli/cmd/build.go index 1ffdde2..9fd9d9a 100644 --- a/cli/cmd/build.go +++ b/cli/cmd/build.go @@ -3,6 +3,8 @@ package cmd import ( "errors" "fmt" + + mssql "github.com/microsoft/go-mssqldb" "github.com/spf13/cobra" "github.com/vippsas/sqlcode" ) @@ -23,7 +25,7 @@ var ( return err } - preprocessed, err := sqlcode.Preprocess(d.CodeBase, schemasuffix) + preprocessed, err := sqlcode.Preprocess(d.CodeBase, schemasuffix, &mssql.Driver{}) if err != nil { return err } diff --git a/cli/cmd/config.go b/cli/cmd/config.go index 6bebbf1..6968802 100644 
--- a/cli/cmd/config.go +++ b/cli/cmd/config.go @@ -5,16 +5,17 @@ import ( "database/sql" "errors" "fmt" - mssql "github.com/denisenkom/go-mssqldb" - "github.com/denisenkom/go-mssqldb/azuread" - "golang.org/x/net/proxy" "io/ioutil" "os" "path" "strings" - _ "github.com/denisenkom/go-mssqldb/azuread" - "github.com/denisenkom/go-mssqldb/msdsn" + mssql "github.com/microsoft/go-mssqldb" + "github.com/microsoft/go-mssqldb/azuread" + "golang.org/x/net/proxy" + + _ "github.com/microsoft/go-mssqldb/azuread" + "github.com/microsoft/go-mssqldb/msdsn" "github.com/sirupsen/logrus" "gopkg.in/yaml.v3" ) diff --git a/cli/cmd/constants.go b/cli/cmd/constants.go index a71b364..90d071a 100644 --- a/cli/cmd/constants.go +++ b/cli/cmd/constants.go @@ -20,18 +20,18 @@ var ( if err != nil { return err } - if len(d.CodeBase.Creates) == 0 && len(d.CodeBase.Declares) == 0 { + if d.CodeBase.Empty() { fmt.Println("No SQL code found in given paths") } - if len(d.CodeBase.Errors) > 0 { + if d.CodeBase.HasErrors() { fmt.Println("Errors:") - for _, e := range d.CodeBase.Errors { + for _, e := range d.CodeBase.Errors() { fmt.Printf("%s:%d:%d: %s\n", e.Pos.File, e.Pos.Line, e.Pos.Line, e.Message) } return nil } fmt.Println("declare") - for i, c := range d.CodeBase.Declares { + for i, c := range d.CodeBase.Declares() { var prefix string if i == 0 { prefix = " " diff --git a/cli/cmd/dep.go b/cli/cmd/dep.go index c0d110c..528b5b5 100644 --- a/cli/cmd/dep.go +++ b/cli/cmd/dep.go @@ -36,16 +36,16 @@ var ( fmt.Println() err = nil } - if len(d.CodeBase.Creates) == 0 && len(d.CodeBase.Declares) == 0 { + if d.CodeBase.Empty() { fmt.Println("No SQL code found in given paths") } - if len(d.CodeBase.Errors) > 0 { + if d.CodeBase.HasErrors() { fmt.Println("Errors:") - for _, e := range d.CodeBase.Errors { + for _, e := range d.CodeBase.Errors() { fmt.Printf("%s:%d:%d: %s\n", e.Pos.File, e.Pos.Line, e.Pos.Line, e.Message) } } - for _, c := range d.CodeBase.Creates { + for _, c := range d.CodeBase.Creates() { 
fmt.Println(c.QuotedName.String() + ":") if len(c.DependsOn) > 0 { fmt.Println(" Uses:") diff --git a/dbintf.go b/dbintf.go index 8257e11..1495942 100644 --- a/dbintf.go +++ b/dbintf.go @@ -3,6 +3,7 @@ package sqlcode import ( "context" "database/sql" + "database/sql/driver" ) type DB interface { @@ -11,6 +12,7 @@ type DB interface { QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row Conn(ctx context.Context) (*sql.Conn, error) BeginTx(ctx context.Context, txOptions *sql.TxOptions) (*sql.Tx, error) + Driver() driver.Driver } var _ DB = &sql.DB{} diff --git a/dbops.go b/dbops.go index 05e5a88..d63278f 100644 --- a/dbops.go +++ b/dbops.go @@ -3,11 +3,26 @@ package sqlcode import ( "context" "database/sql" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" ) func Exists(ctx context.Context, dbc DB, schemasuffix string) (bool, error) { var schemaID int - err := dbc.QueryRowContext(ctx, `select isnull(schema_id(@p1), 0)`, SchemaName(schemasuffix)).Scan(&schemaID) + + driver := dbc.Driver() + var qs string + + if _, ok := driver.(*mssql.Driver); ok { + qs = `select isnull(schema_id(@p1), 0)` + } + if _, ok := driver.(*stdlib.Driver); ok { + qs = `select coalesce((select oid from pg_namespace where nspname = $1),0)` + } + + err := dbc.QueryRowContext(ctx, qs, SchemaName(schemasuffix)).Scan(&schemaID) if err != nil { return false, err } @@ -19,8 +34,24 @@ func Drop(ctx context.Context, dbc DB, schemasuffix string) error { if err != nil { return err } - _, err = tx.ExecContext(ctx, `sqlcode.DropCodeSchema`, - sql.Named("schemasuffix", schemasuffix)) + + var qs string + var arg = []interface{}{} + driver := dbc.Driver() + + if _, ok := driver.(*mssql.Driver); ok { + qs = `sqlcode.DropCodeSchema` + arg = []interface{}{sql.Named("schemasuffix", schemasuffix)} + } + + if _, ok := dbc.Driver().(*stdlib.Driver); ok { + qs = `call sqlcode.dropcodeschema(@schemasuffix)` + arg = []interface{}{ + 
pgx.NamedArgs{"schemasuffix": schemasuffix}, + } + } + + _, err = tx.ExecContext(ctx, qs, arg...) if err != nil { _ = tx.Rollback() return err diff --git a/deployable.go b/deployable.go index 7e2b178..135fd26 100644 --- a/deployable.go +++ b/deployable.go @@ -10,7 +10,10 @@ import ( "strings" "time" - mssql "github.com/denisenkom/go-mssqldb" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" + pgxstdlib "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" "github.com/vippsas/sqlcode/sqlparser" ) @@ -77,24 +80,25 @@ func impersonate(ctx context.Context, dbc DB, username string, f func(conn *sql. // Upload will create and upload the schema; resulting in an error // if the schema already exists func (d *Deployable) Upload(ctx context.Context, dbc DB) error { - // First, impersonate a user with minimal privileges to get at least - // some level of sandboxing so that migration scripts can't do anything - // the caller didn't expect them to. - return impersonate(ctx, dbc, "sqlcode-deploy-sandbox-user", func(conn *sql.Conn) error { + driver := dbc.Driver() + qs := make(map[string][]interface{}, 1) + + var uploadFunc = func(conn *sql.Conn) error { tx, err := conn.BeginTx(ctx, nil) if err != nil { return err } - _, err = tx.ExecContext(ctx, `sqlcode.CreateCodeSchema`, - sql.Named("schemasuffix", d.SchemaSuffix), - ) - if err != nil { - _ = tx.Rollback() - return err + for q, args := range qs { + _, err = tx.ExecContext(ctx, q, args...) 
+ + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("failed to execute (%s) with arg(%s) in schema %s: %w", q, args, d.SchemaSuffix, err) + } } - preprocessed, err := Preprocess(d.CodeBase, d.SchemaSuffix) + preprocessed, err := Preprocess(d.CodeBase, d.SchemaSuffix, dbc.Driver()) if err != nil { _ = tx.Rollback() return err @@ -103,15 +107,16 @@ func (d *Deployable) Upload(ctx context.Context, dbc DB) error { _, err := tx.ExecContext(ctx, b.Lines) if err != nil { _ = tx.Rollback() - sqlerr, ok := err.(mssql.Error) - if !ok { - return err - } else { - return SQLUserError{ + if sqlerr, ok := err.(mssql.Error); ok { + return MSSQLUserError{ Wrapped: sqlerr, Batch: b, } } + + // TODO(ks) PGSQLUserError + return fmt.Errorf("failed to upload deployable:%s in schema:%s:%w", d.CodeBase, d.SchemaSuffix, err) + } } err = tx.Commit() @@ -123,8 +128,36 @@ func (d *Deployable) Upload(ctx context.Context, dbc DB) error { return nil - }) + } + + if _, ok := driver.(*mssql.Driver); ok { + // First, impersonate a user with minimal privileges to get at least + // some level of sandboxing so that migration scripts can't do anything + // the caller didn't expect them to. 
+ qs["sqlcode.CreateCodeSchema"] = []interface { + }{ + sql.Named("schemasuffix", d.SchemaSuffix), + } + + return impersonate(ctx, dbc, "sqlcode-deploy-sandbox-user", uploadFunc) + } + + if _, ok := driver.(*stdlib.Driver); ok { + qs[`set role "sqlcode-deploy-sandbox-user"`] = nil + qs[`call sqlcode.createcodeschema(@schemasuffix)`] = []interface{}{ + pgx.NamedArgs{"schemasuffix": d.SchemaSuffix}, + } + conn, err := dbc.Conn(ctx) + if err != nil { + return err + } + defer func() { + _ = conn.Close() + }() + return uploadFunc(conn) + } + return fmt.Errorf("failed to determine sql driver to upload schema: %s", d.SchemaSuffix) } // EnsureUploaded checks that the schema with the suffix already exists, @@ -137,36 +170,51 @@ func (d *Deployable) EnsureUploaded(ctx context.Context, dbc DB) error { return nil } + driver := dbc.Driver() lockResourceName := "sqlcode.EnsureUploaded/" + d.SchemaSuffix + var lockRetCode int + var lockQs string + var unlockQs string + var err error + // When a lock is opened with the Transaction lock owner, // that lock is released when the transaction is committed or rolled back. 
-	var lockRetCode int
-	err := dbc.QueryRowContext(ctx, `
-declare @retcode int;
-exec @retcode = sp_getapplock @Resource = @resource, @LockMode = 'Shared', @LockOwner = 'Session', @LockTimeout = @timeoutMs;
-select @retcode;
-`,
-		sql.Named("resource", lockResourceName),
-		sql.Named("timeoutMs", 20000),
-	).Scan(&lockRetCode)
+	if _, ok := driver.(*pgxstdlib.Driver); ok {
+		lockQs = `select sqlcode.get_applock(@resource, @timeout)`
+		unlockQs = `select sqlcode.release_applock(@resource)`
+
+		err = dbc.QueryRowContext(ctx, lockQs, pgx.NamedArgs{
+			"resource": lockResourceName,
+			"timeout":  20000,
+		}).Scan(&lockRetCode)
+
+		defer func() {
+			dbc.ExecContext(ctx, unlockQs, pgx.NamedArgs{"resource": lockResourceName})
+		}()
+	}
+
+	if _, ok := driver.(*mssql.Driver); ok {
+		// TODO
+
+		defer func() {
+			// TODO: This returns an error if the lock is already released
+			_, _ = dbc.ExecContext(ctx, unlockQs,
+				sql.Named("Resource", lockResourceName),
+				sql.Named("LockOwner", "Session"),
+			)
+		}()
+	}
+
 	if err != nil {
 		return err
 	}
 	if lockRetCode < 0 {
 		return errors.New("was not able to get lock before timeout")
 	}
-
-	defer func() {
-		_, _ = dbc.ExecContext(ctx, `sp_releaseapplock`,
-			sql.Named("Resource", lockResourceName),
-			sql.Named("LockOwner", "Session"),
-		)
-	}()
-
 	exists, err := Exists(ctx, dbc, d.SchemaSuffix)
 	if err != nil {
-		return err
+		return fmt.Errorf("unable to determine if schema %s exists: %w", d.SchemaSuffix, err)
 	}
 
 	if exists {
@@ -195,11 +243,28 @@ func (d Deployable) DropAndUpload(ctx context.Context, dbc DB) error {
 	}
 
 // Patch will preprocess the sql passed in so that it will call SQL code
-// deployed by the receiver Deployable
+// deployed by the receiver Deployable for SQL Server.
+// NOTE: This will be deprecated and eventually replaced with CodePatch.
func (d Deployable) Patch(sql string) string { return preprocessString(d.SchemaSuffix, sql) } +// CodePatch will preprocess the sql passed in to call +// the correct SQL code deployed to the provided database. +// Q: Nameing? DBPatch, PatchV2, ??? +func (d Deployable) CodePatch(dbc *sql.DB, sql string) string { + driver := dbc.Driver() + if _, ok := driver.(*mssql.Driver); ok { + return codeSchemaRegexp.ReplaceAllString(sql, fmt.Sprintf(`[code@%s]`, d.SchemaSuffix)) + } + + if _, ok := driver.(*stdlib.Driver); ok { + return codeSchemaRegexp.ReplaceAllString(sql, fmt.Sprintf(`"code@%s"`, d.SchemaSuffix)) + } + + panic("unhandled sql driver") +} + func (d *Deployable) markAsUploaded(dbc DB) { d.uploaded[dbc] = struct{}{} } @@ -211,9 +276,8 @@ func (d *Deployable) IsUploadedFromCache(dbc DB) bool { // TODO: StringConst. This requires parsing a SQL literal, a bit too complex // to code up just-in-case - func (d Deployable) IntConst(s string) (int, error) { - for _, declare := range d.CodeBase.Declares { + for _, declare := range d.CodeBase.Declares() { if declare.VariableName == s { // TODO: more robust integer SQL parsing than this; only works // in most common cases @@ -247,8 +311,8 @@ type Options struct { func Include(opts Options, fsys ...fs.FS) (result Deployable, err error) { parsedFiles, doc, err := sqlparser.ParseFilesystems(fsys, opts.IncludeTags) - if len(doc.Errors) > 0 && !opts.PartialParseResults { - return Deployable{}, SQLCodeParseErrors{Errors: doc.Errors} + if doc.HasErrors() && !opts.PartialParseResults { + return Deployable{}, SQLCodeParseErrors{Errors: doc.Errors()} } result.CodeBase = doc @@ -280,10 +344,28 @@ func (s *SchemaObject) Suffix() string { // Return a list of sqlcode schemas that have been uploaded to the database. // This includes all current and unused schemas. 
-func (d *Deployable) ListUploaded(ctx context.Context, dbc DB) []*SchemaObject { +func (d *Deployable) ListUploaded(ctx context.Context, dbc DB) ([]*SchemaObject, error) { objects := []*SchemaObject{} - impersonate(ctx, dbc, "sqlcode-deploy-sandbox-user", func(conn *sql.Conn) error { - rows, err := conn.QueryContext(ctx, ` + driver := dbc.Driver() + var qs string + + var list = func(conn *sql.Conn) error { + rows, err := conn.QueryContext(ctx, qs) + if err != nil { + return err + } + + for rows.Next() { + zero := &SchemaObject{} + rows.Scan(&zero.Name, &zero.Objects, &zero.SchemaId, &zero.CreateDate, &zero.ModifyDate) + objects = append(objects, zero) + } + + return nil + } + + if _, ok := driver.(*mssql.Driver); ok { + qs = ` select s.name , s.schema_id @@ -298,18 +380,32 @@ func (d *Deployable) ListUploaded(ctx context.Context, dbc DB) []*SchemaObject { from sys.objects o where o.schema_id = s.schema_id ) as o - where s.name like 'code@%'`) + where s.name like 'code@%'` + impersonate(ctx, dbc, "sqlcode-deploy-sandbox-user", list) + } + + // TODO(ks) the timestamps for schemas + if _, ok := driver.(*stdlib.Driver); ok { + qs = `select nspname as name + , oid as schema_id + , 0 as objects + , '' as create_date + , '' as modify_date + from pg_namespace + where nspname like 'code@%' + order by nspname` + conn, err := dbc.Conn(ctx) if err != nil { - return err + return nil, err } - - for rows.Next() { - zero := &SchemaObject{} - rows.Scan(&zero.Name, &zero.Objects, &zero.SchemaId, &zero.CreateDate, &zero.ModifyDate) - objects = append(objects, zero) + err = list(conn) + if err != nil { + return nil, err } + defer func() { + _ = conn.Close() + }() + } - return nil - }) - return objects + return objects, nil } diff --git a/deployable_test.go b/deployable_test.go index 1e9dac5..7b87b57 100644 --- a/deployable_test.go +++ b/deployable_test.go @@ -21,5 +21,11 @@ declare @EnumInt int = 1, @EnumString varchar(max) = 'hello'; n, err := d.IntConst("@EnumInt") 
 	require.NoError(t, err)
 	assert.Equal(t, 1, n)
+}
 
+func TestPatch(t *testing.T) {
+	t.Run("mssql schemasuffix", func(t *testing.T) {
+		d := Deployable{}
+		require.Equal(t, "[code@].Foo", d.Patch("[code].Foo"))
+	})
 }
diff --git a/docker-compose.mssql.yml b/docker-compose.mssql.yml
new file mode 100644
index 0000000..f2e4471
--- /dev/null
+++ b/docker-compose.mssql.yml
@@ -0,0 +1,25 @@
+services:
+  mssql:
+    image: mcr.microsoft.com/mssql/server:latest
+    networks:
+      - mssql
+    environment:
+      ACCEPT_EULA: "Y"
+      SA_PASSWORD: VippsPw1
+    healthcheck:
+      test: ["CMD", "/opt/mssql-tools18/bin/sqlcmd", "-C", "-Usa", "-PVippsPw1", "-Q", "select 1"]
+      interval: 1s
+      retries: 20
+  test:
+    build:
+      dockerfile: dockerfile.test
+    networks:
+      - mssql
+    environment:
+      SQLSERVER_DSN: sqlserver://mssql:1433?database=master&user id=sa&password=VippsPw1
+      GODEBUG: "x509negativeserial=1"
+    depends_on:
+      mssql:
+        condition: service_healthy
+networks:
+  mssql:
diff --git a/docker-compose.pgsql.yml b/docker-compose.pgsql.yml
new file mode 100644
index 0000000..a6d7a0b
--- /dev/null
+++ b/docker-compose.pgsql.yml
@@ -0,0 +1,27 @@
+services:
+  postgres:
+    image: postgres
+    networks:
+      - postgres
+    environment:
+      POSTGRES_PASSWORD: VippsPw1
+      POSTGRES_USER: sa
+      POSTGRES_DB: master
+      PGOPTIONS: "-c log_error_verbosity=verbose -c log_statement=all"
+    healthcheck:
+      test: ["CMD-SHELL", "pg_isready -U sa -d master"]
+      interval: 1s
+      retries: 20
+  test:
+    build:
+      dockerfile: dockerfile.test
+    networks:
+      - postgres
+    environment:
+      SQLSERVER_DSN: postgresql://sa:VippsPw1@postgres:5432/master?sslmode=disable
+      GODEBUG: "x509negativeserial=1"
+    depends_on:
+      postgres:
+        condition: service_healthy
+networks:
+  postgres:
diff --git a/docker-compose.test.yml b/docker-compose.test.yml
deleted file mode 100644
index 94da3ec..0000000
--- a/docker-compose.test.yml
+++ /dev/null
@@ -1,15 +0,0 @@
-services:
-  #
-  # mssql
-  #
-  mssql:
-    image: mcr.microsoft.com/mssql/server:latest
-
-    hostname: mssql
-
container_name: mssql - network_mode: bridge - ports: - - "1433:1433" - environment: - ACCEPT_EULA: "Y" - SA_PASSWORD: VippsPw1 diff --git a/dockerfile.test b/dockerfile.test new file mode 100644 index 0000000..dbd2061 --- /dev/null +++ b/dockerfile.test @@ -0,0 +1,5 @@ +FROM golang:1.25 AS builder +WORKDIR /sqlcode +COPY . . +RUN go mod tidy +CMD ["go", "test", "-v", "./..."] \ No newline at end of file diff --git a/example/basic/example.go b/example/basic/example.go index abe1194..9406915 100644 --- a/example/basic/example.go +++ b/example/basic/example.go @@ -1,3 +1,6 @@ +//go:build examples +// +build examples + package example import ( diff --git a/example/basic/example_test.go b/example/basic/example_test.go index 079bd91..0c78c63 100644 --- a/example/basic/example_test.go +++ b/example/basic/example_test.go @@ -1,13 +1,17 @@ +//go:build examples +// +build examples + package example import ( "context" "fmt" + "testing" + "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vippsas/sqlcode/sqltest" - "testing" - "time" ) func TestPreprocess(t *testing.T) { diff --git a/go.mod b/go.mod index fa5e2c6..a1ad8a1 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,9 @@ go 1.24.3 require ( github.com/alecthomas/repr v0.5.2 - github.com/denisenkom/go-mssqldb v0.12.3 github.com/gofrs/uuid v4.4.0+incompatible + github.com/jackc/pgx/v5 v5.7.6 + github.com/microsoft/go-mssqldb v1.9.5 github.com/sirupsen/logrus v1.9.3 github.com/smasher164/xid v0.1.2 github.com/spf13/cobra v1.10.1 @@ -15,17 +16,26 @@ require ( ) require ( - github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0 // indirect - github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0 // indirect - github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 // indirect + 
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/golang-jwt/jwt/v5 v5.2.2 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/shopspring/decimal v1.4.0 // indirect github.com/spf13/pflag v1.0.9 // indirect golang.org/x/crypto v0.43.0 // indirect + golang.org/x/sync v0.17.0 // indirect golang.org/x/sys v0.37.0 // indirect golang.org/x/text v0.30.0 // indirect ) diff --git a/go.sum b/go.sum index d0ecd83..6fd729c 100644 --- a/go.sum +++ b/go.sum @@ -1,34 +1,70 @@ -github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0 h1:lhSJz9RMbJcTgxifR1hUNJnn6CNYtbgEDtQV22/9RBA= -github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0/go.mod h1:h6H6c8enJmmocHUbLiiGY6sx7f9i+X3m1CHdd5c6Rdw= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0 h1:OYa9vmRX2XC5GXRAzeggG12sF/z5D9Ahtdm9EJ00WN4= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0/go.mod h1:HcM1YX14R7CJcghJGOYCgdezslRSVzqwLf/q+4Y2r/0= -github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0 h1:v9p9TfTbf7AwNb5NYQt7hI41IfPoLFiFkLtb+bmGjT0= -github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0/go.mod h1:yqy467j36fJxcRV2TzfVZ1pCb5vxm4BtZPUdYWe/Xo8= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM= 
+github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 h1:FPKJS1T+clwv+OLGt13a8UjqeRuh0O4SJ3lUriThc+4= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1/go.mod h1:j2chePtV91HrC22tGoRX3sGY42uF13WzmmV80/OdVAA= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1 h1:Wgf5rZba3YZqeTNJPtvqZoBu1sBN/L4sry+u2U3Y75w= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1/go.mod h1:xxCBG/f/4Vbmh2XQJBsOmNdxWUY5j/s27jujKPbQf14= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1 h1:bFWuoEKg+gImo7pvkiQEFAc8ocibADgXeiLAxWhWmkI= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1/go.mod h1:Vih/3yc6yac2JzU4hzpaDupBJP0Flaia9rXXrU8xyww= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs= github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/denisenkom/go-mssqldb v0.12.3 h1:pBSGx9Tq67pBOTLmxNuirNTeB8Vjmf886Kx+8Y+8shw= -github.com/denisenkom/go-mssqldb v0.12.3/go.mod h1:k0mtMFOnU+AihqFxPMiF05rtiDrorD1Vrm1KEz5hxDo= -github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA= github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= -github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/inconshreveable/mousetrap v1.1.0 
h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= -github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= +github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU= +github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/microsoft/go-mssqldb v1.9.5 h1:orwya0X/5bsL1o+KasupTkk2eNTNFkTQG0BEe/HxCn0= +github.com/microsoft/go-mssqldb v1.9.5/go.mod h1:VCP2a0KEZZtGLRHd1PsLavLFYy/3xX2yJUPycv3Sr2Q= 
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI= +github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smasher164/xid v0.1.2 h1:erplXSdBRIIw+MrwjJ/m8sLN2XY16UGzpTA0E2Ru6HA= @@ -38,40 +74,27 @@ github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4 github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= 
-golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20210610132358-84b48f89b13b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= golang.org/x/sys v0.37.0/go.mod 
h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/migrations/0001.sqlcode.pgsql b/migrations/0001.sqlcode.pgsql new file mode 100644 index 0000000..1395bc4 --- /dev/null +++ b/migrations/0001.sqlcode.pgsql @@ -0,0 +1,229 @@ +-- ====================================================================== +-- create users and roles +-- ====================================================================== +do $$ +begin + -- role 
that will own the sqlcode schemas (actual code schemas), with no login + if not exists ( + select 1 + from pg_roles + where rolname = 'sqlcode-user-with-no-permissions' + ) then + create role "sqlcode-user-with-no-permissions" nologin; + end if; + + -- role that owns the management schema/procedures (security definer) + if not exists ( + select 1 + from pg_roles + where rolname = 'sqlcode-definer-role' + ) then + create role "sqlcode-definer-role" nologin; + end if; + + -- role that gets execute / usage on code schemas (for humans debugging etc.) + if not exists ( + select 1 + from pg_roles + where rolname = 'sqlcode-execute-role' + ) then + create role "sqlcode-execute-role"; + end if; + + -- role for calling createcodeschema / dropcodeschema; + -- this role does not own the procedures, it only calls them. + if not exists ( + select 1 + from pg_roles + where rolname = 'sqlcode-deploy-role' + ) then + create role "sqlcode-deploy-role"; + end if; + + -- sandbox role used during deploys, which only has sqlcode-deploy-role + if not exists ( + select 1 + from pg_roles + where rolname = 'sqlcode-deploy-sandbox-user' + ) then + create role "sqlcode-deploy-sandbox-user" nologin; + end if; +end; +$$; + +-- ====================================================================== +-- grant permissions / role memberships +-- ====================================================================== + +do $$ +begin + -- grant "sqlcode-deploy-role" to "sqlcode-deploy-sandbox-user" + if not exists ( + select 1 + from pg_auth_members m + join pg_roles r_role on r_role.oid = m.roleid + join pg_roles r_member on r_member.oid = m.member + where r_role.rolname = 'sqlcode-deploy-role' + and r_member.rolname = 'sqlcode-deploy-sandbox-user' + ) then + grant "sqlcode-deploy-role" to "sqlcode-deploy-sandbox-user"; + end if; +end; +$$; + +-- ====================================================================== +-- create schema for management code (owner = definer role) +-- 
====================================================================== + +do $$ +begin + if not exists ( + select 1 from pg_namespace where nspname = 'sqlcode' + ) then + create schema sqlcode authorization "sqlcode-definer-role"; + end if; +end; +$$; + +-- ====================================================================== +-- create procedures (security definer) +-- ====================================================================== + +create or replace procedure sqlcode.createcodeschema(schemasuffix varchar) +language plpgsql +security definer +as $$ +declare + schemaname text := format('code@%s', schemasuffix); +begin + -- harden search_path for security-definer (optional but recommended) + perform set_config('search_path', 'pg_catalog', true); + + -- create the schema owned by "sqlcode-user-with-no-permissions" + execute format( + 'create schema if not exists %I authorization %I', + schemaname, + 'sqlcode-user-with-no-permissions' + ); + + -- grant schema privileges + execute format( + 'grant usage on schema %I to %I', + schemaname, + 'sqlcode-execute-role' + ); + + execute format( + 'grant usage, create on schema %I to %I', + schemaname, + 'sqlcode-deploy-role' + ); + +exception + when others then + raise; +end; +$$; + +create or replace procedure sqlcode.dropcodeschema(schemasuffix varchar) +language plpgsql +security definer +as $$ +declare + schemaname text := format('code@%s', schemasuffix); + schema_exists boolean; +begin + -- harden search_path for security-definer (optional but recommended) + perform set_config('search_path', 'pg_catalog', true); + + -- check schema existence + select exists ( + select 1 + from pg_namespace + where nspname = schemaname + ) into schema_exists; + + if not schema_exists then + raise exception 'schema "%" not found', schemaname; + end if; + + -- drop the schema and all objects within it + execute format('drop schema %I cascade', schemaname); + +exception + when others then + raise; +end; +$$; + +-- similar behaviour 
as mssql getapplock +-- PostgreSQL advisory locks are session-based by default +create or replace function sqlcode.get_applock( + resource text, + timeout_ms integer default 0 +) +returns integer +language plpgsql +as $$ +declare + resource_key bigint; + acquired boolean; + waited_ms integer := 0; +begin + -- convert string to advisory-lock key + select hashtext(resource) into resource_key; + + -- attempt lock with timeout loop + loop + select pg_try_advisory_lock_shared(resource_key) + into acquired; + + if acquired then + return 1; -- lock acquired (success) + end if; + + if waited_ms >= timeout_ms then + return 0; -- timeout + end if; + + perform pg_sleep(0.01); -- sleep 10 ms + waited_ms := waited_ms + 10; + end loop; + + return null; -- safety fallback (should never hit) +end; +$$; + +create or replace function sqlcode.release_applock(resource text) +returns boolean +language sql +as $$ + select pg_advisory_unlock_shared(hashtext(resource)); +$$; + + +-- ensure procedures are owned by the definer role +alter procedure sqlcode.createcodeschema(varchar) + owner to "sqlcode-definer-role"; + +alter procedure sqlcode.dropcodeschema(varchar) + owner to "sqlcode-definer-role"; + +-- ====================================================================== +-- privileges on the procedures and base schema +-- ====================================================================== + +-- allow deploy role to call the management procedures +grant execute on procedure sqlcode.createcodeschema(varchar) + to "sqlcode-deploy-role"; + +grant execute on procedure sqlcode.dropcodeschema(varchar) + to "sqlcode-deploy-role"; + +grant "sqlcode-user-with-no-permissions" + to "sqlcode-definer-role"; + +-- usually deploy role does not need create in the sqlcode management schema +-- (the procedures handle creation in separate "code@..." 
schemas) +grant usage on schema sqlcode + to "sqlcode-deploy-role"; diff --git a/error.go b/mssql_error.go similarity index 88% rename from error.go rename to mssql_error.go index 6131fbf..d6f531e 100644 --- a/error.go +++ b/mssql_error.go @@ -3,17 +3,18 @@ package sqlcode import ( "bytes" "fmt" - mssql "github.com/denisenkom/go-mssqldb" - "github.com/vippsas/sqlcode/sqlparser" "strings" + + mssql "github.com/microsoft/go-mssqldb" + "github.com/vippsas/sqlcode/sqlparser" ) -type SQLUserError struct { +type MSSQLUserError struct { Wrapped mssql.Error Batch Batch } -func (s SQLUserError) Error() string { +func (s MSSQLUserError) Error() string { var buf bytes.Buffer if _, fmterr := fmt.Fprintf(&buf, "\n"); fmterr != nil { diff --git a/preprocess.go b/preprocess.go index 5a8adab..9478c3f 100644 --- a/preprocess.go +++ b/preprocess.go @@ -2,20 +2,24 @@ package sqlcode import ( "crypto/sha256" + "database/sql/driver" "encoding/hex" "errors" "fmt" - "github.com/vippsas/sqlcode/sqlparser" + "reflect" "regexp" "strings" + + "github.com/jackc/pgx/v5/stdlib" + "github.com/vippsas/sqlcode/sqlparser" ) func SchemaSuffixFromHash(doc sqlparser.Document) string { hasher := sha256.New() - for _, dec := range doc.Declares { + for _, dec := range doc.Declares() { hasher.Write([]byte(dec.String() + "\n")) } - for _, c := range doc.Creates { + for _, c := range doc.Creates() { if err := c.SerializeBytes(hasher); err != nil { panic(err) // asserting that sha256 will never return a write error... 
} @@ -128,7 +132,6 @@ func sqlcodeTransformCreate(declares map[string]string, c sqlparser.Create, quot result.lineNumberCorrections = append(result.lineNumberCorrections, lineNumberCorrection{relativeLine, newlineCount}) } } - if _, err = w.WriteString(token); err != nil { return } @@ -138,7 +141,7 @@ func sqlcodeTransformCreate(declares map[string]string, c sqlparser.Create, quot return } -func Preprocess(doc sqlparser.Document, schemasuffix string) (PreprocessedFile, error) { +func Preprocess(doc sqlparser.Document, schemasuffix string, driver driver.Driver) (PreprocessedFile, error) { var result PreprocessedFile if strings.Contains(schemasuffix, "]") { @@ -146,17 +149,32 @@ func Preprocess(doc sqlparser.Document, schemasuffix string) (PreprocessedFile, } declares := make(map[string]string) - for _, dec := range doc.Declares { + for _, dec := range doc.Declares() { declares[dec.VariableName] = dec.Literal.RawValue } - for _, create := range doc.Creates { + // The current sql driver that we are preparring for + currentDriver := reflect.TypeOf(driver) + + // the default target for mssql + target := fmt.Sprintf(`[code@%s]`, schemasuffix) + + // pgsql target + if _, ok := driver.(*stdlib.Driver); ok { + target = fmt.Sprintf(`"code@%s"`, schemasuffix) + } + + for _, create := range doc.Creates() { if len(create.Body) == 0 { continue } - batch, err := sqlcodeTransformCreate(declares, create, "[code@"+schemasuffix+"]") + if !currentDriver.AssignableTo(reflect.TypeOf(create.Driver)) { + // this batch is for a different sql driver + continue + } + batch, err := sqlcodeTransformCreate(declares, create, target) if err != nil { - return result, err + return result, fmt.Errorf("failed to transform create: %w", err) } result.Batches = append(result.Batches, batch) } diff --git a/preprocess_test.go b/preprocess_test.go index bf976e8..85132e6 100644 --- a/preprocess_test.go +++ b/preprocess_test.go @@ -1,24 +1,16 @@ package sqlcode import ( + "strings" "testing" + 
"github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vippsas/sqlcode/sqlparser" ) -func TestSchemaSuffixFromHash(t *testing.T) { - t.Run("returns a unique hash", func(t *testing.T) { - doc := sqlparser.Document{ - Declares: []sqlparser.Declare{}, - } - - value := SchemaSuffixFromHash(doc) - require.Equal(t, value, SchemaSuffixFromHash(doc)) - }) -} - func TestLineNumberInInput(t *testing.T) { // Scenario: @@ -63,3 +55,340 @@ func TestLineNumberInInput(t *testing.T) { } assert.Equal(t, expectedInputLineNumbers, inputlines[1:]) } + +func TestSchemaSuffixFromHash(t *testing.T) { + t.Run("returns a unique hash", func(t *testing.T) { + doc := sqlparser.NewDocumentFromExtension(".sql") + value := SchemaSuffixFromHash(doc) + require.Equal(t, value, SchemaSuffixFromHash(doc)) + }) + + t.Run("returns consistent hash", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +declare @EnumFoo int = 1; +go +create procedure [code].Test as begin end +`) + + suffix1 := SchemaSuffixFromHash(doc) + suffix2 := SchemaSuffixFromHash(doc) + + assert.Equal(t, suffix1, suffix2) + assert.Len(t, suffix1, 12) // 6 bytes = 12 hex chars + }) + + t.Run("different content yields different hash", func(t *testing.T) { + doc1 := sqlparser.ParseString("test.sql", ` +declare @EnumFoo int = 1; +go +create procedure [code].Test1 as begin end +`) + doc2 := sqlparser.ParseString("test.sql", ` +declare @EnumFoo int = 2; +go +create procedure [code].Test2 as begin end +`) + + suffix1 := SchemaSuffixFromHash(doc1) + suffix2 := SchemaSuffixFromHash(doc2) + + assert.NotEqual(t, suffix1, suffix2) + }) + + t.Run("empty document has hash", func(t *testing.T) { + doc := sqlparser.NewDocumentFromExtension(".pgsql") + suffix := SchemaSuffixFromHash(doc) + assert.Len(t, suffix, 12) + }) +} + +func TestSchemaName(t *testing.T) { + assert.Equal(t, "code@abc123", SchemaName("abc123")) + 
assert.Equal(t, "code@", SchemaName("")) +} + +func TestBatchLineNumberInInput(t *testing.T) { + t.Run("no corrections", func(t *testing.T) { + b := Batch{ + StartPos: sqlparser.Pos{Line: 10, Col: 1}, + Lines: "line1\nline2\nline3", + } + + assert.Equal(t, 10, b.LineNumberInInput(1)) + assert.Equal(t, 11, b.LineNumberInInput(2)) + assert.Equal(t, 12, b.LineNumberInInput(3)) + }) + + t.Run("with corrections", func(t *testing.T) { + b := Batch{ + StartPos: sqlparser.Pos{Line: 10, Col: 1}, + Lines: "line1\nline2\nextra1\nextra2\nline3", + lineNumberCorrections: []lineNumberCorrection{ + {inputLineNumber: 2, extraLinesInOutput: 2}, // line 2 became 3 lines + }, + } + + assert.Equal(t, 10, b.LineNumberInInput(1)) // line 1 -> input line 10 + assert.Equal(t, 11, b.LineNumberInInput(2)) // line 2 -> input line 11 + assert.Equal(t, 11, b.LineNumberInInput(3)) // extra line -> still input line 11 + assert.Equal(t, 11, b.LineNumberInInput(4)) // extra line -> still input line 11 + assert.Equal(t, 12, b.LineNumberInInput(5)) // line 3 -> input line 12 + }) +} + +func TestBatchRelativeLineNumberInInput(t *testing.T) { + t.Run("simple case with no corrections", func(t *testing.T) { + b := Batch{ + lineNumberCorrections: []lineNumberCorrection{}, + } + + assert.Equal(t, 1, b.RelativeLineNumberInInput(1)) + assert.Equal(t, 5, b.RelativeLineNumberInInput(5)) + }) + + t.Run("single correction", func(t *testing.T) { + b := Batch{ + lineNumberCorrections: []lineNumberCorrection{ + {inputLineNumber: 3, extraLinesInOutput: 2}, + }, + } + + assert.Equal(t, 1, b.RelativeLineNumberInInput(1)) + assert.Equal(t, 2, b.RelativeLineNumberInInput(2)) + assert.Equal(t, 3, b.RelativeLineNumberInInput(3)) + assert.Equal(t, 3, b.RelativeLineNumberInInput(4)) // extra line + assert.Equal(t, 3, b.RelativeLineNumberInInput(5)) // extra line + assert.Equal(t, 4, b.RelativeLineNumberInInput(6)) + }) + + t.Run("multiple corrections", func(t *testing.T) { + b := Batch{ + lineNumberCorrections: 
[]lineNumberCorrection{ + {inputLineNumber: 2, extraLinesInOutput: 1}, + {inputLineNumber: 5, extraLinesInOutput: 3}, + }, + } + + assert.Equal(t, 1, b.RelativeLineNumberInInput(1)) + assert.Equal(t, 2, b.RelativeLineNumberInInput(2)) + assert.Equal(t, 2, b.RelativeLineNumberInInput(3)) // extra from line 2 + assert.Equal(t, 3, b.RelativeLineNumberInInput(4)) + assert.Equal(t, 4, b.RelativeLineNumberInInput(5)) + assert.Equal(t, 5, b.RelativeLineNumberInInput(6)) + assert.Equal(t, 5, b.RelativeLineNumberInInput(7)) // extra from line 5 + assert.Equal(t, 5, b.RelativeLineNumberInInput(8)) // extra from line 5 + assert.Equal(t, 5, b.RelativeLineNumberInInput(9)) // extra from line 5 + assert.Equal(t, 6, b.RelativeLineNumberInInput(10)) + }) +} + +func TestPreprocess(t *testing.T) { + t.Run("basic procedure with schema replacement", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +create procedure [code].Test as +begin + select 1 +end +`) + + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + assert.Contains(t, result.Batches[0].Lines, "[code@abc123].") + assert.NotContains(t, result.Batches[0].Lines, "[code].") + }) + + t.Run("postgres uses unquoted schema name", func(t *testing.T) { + doc := sqlparser.ParseString("test.pgsql", ` +create procedure [code].test() as $$ +begin + perform 1; +end; +$$ language plpgsql; +`) + result, err := Preprocess(doc, "abc123", &stdlib.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + assert.Contains(t, result.Batches[0].Lines, `"code@abc123".`) + assert.NotContains(t, result.Batches[0].Lines, "[code@abc123].") + }) + + t.Run("replaces enum constants", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +declare @EnumStatus int = 42; +go +create procedure [code].Test as +begin + select @EnumStatus +end +`) + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + 
require.Len(t, result.Batches, 1) + + batch := result.Batches[0].Lines + assert.Contains(t, batch, "42/*=@EnumStatus*/") + assert.NotContains(t, batch, "@EnumStatus\n") // shouldn't have bare reference + }) + + t.Run("handles multiline string constants", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +declare @EnumMulti nvarchar(max) = N'line1 +line2 +line3'; +go +create procedure [code].Test as +begin + select @EnumMulti +end +`) + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + batch := result.Batches[0] + assert.Contains(t, batch.Lines, "N'line1\nline2\nline3'/*=@EnumMulti*/") + // Should have line number corrections for the 2 extra lines + assert.Len(t, batch.lineNumberCorrections, 1) + assert.Equal(t, 2, batch.lineNumberCorrections[0].extraLinesInOutput) + }) + + t.Run("error on undeclared constant", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +create procedure [code].Test as +begin + select @EnumUndeclared +end +`) + _, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.Error(t, err) + + var preprocErr PreprocessorError + require.ErrorAs(t, err, &preprocErr) + assert.Contains(t, preprocErr.Message, "@EnumUndeclared") + assert.Contains(t, preprocErr.Message, "not declared") + }) + + t.Run("error on schema suffix with bracket", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +create procedure [code].Test as begin end +`) + _, err := Preprocess(doc, "abc]123", &mssql.Driver{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "schemasuffix cannot contain") + }) + + t.Run("handles multiple creates", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +create procedure [code].Proc1 as begin select 1 end +go +create procedure [code].Proc2 as begin select 2 end +`) + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + assert.Len(t, result.Batches, 2) + + 
assert.Contains(t, result.Batches[0].Lines, "Proc1") + assert.Contains(t, result.Batches[1].Lines, "Proc2") + }) + + t.Run("handles multiple constants in same procedure", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +declare @EnumA int = 1, @EnumB int = 2; +go +create procedure [code].Test as +begin + select @EnumA, @EnumB +end +`) + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + batch := result.Batches[0].Lines + assert.Contains(t, batch, "1/*=@EnumA*/") + assert.Contains(t, batch, "2/*=@EnumB*/") + }) + + t.Run("preserves comments and formatting", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +-- This is a test procedure +create procedure [code].Test as +begin + /* multi + line + comment */ + select 1 +end +`) + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + batch := result.Batches[0].Lines + assert.Contains(t, batch, "-- This is a test procedure") + assert.Contains(t, batch, "/* multi") + }) + + t.Run("handles const and global prefixes", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +declare @ConstValue int = 100; +declare @GlobalSetting nvarchar(50) = N'test'; +go +create procedure [code].Test as +begin + select @ConstValue, @GlobalSetting +end +`) + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + batch := result.Batches[0].Lines + assert.Contains(t, batch, "100/*=@ConstValue*/") + assert.NotContains(t, batch, "N'test'/*=@GlobalSetting*/") + }) +} + +func TestPreprocessString(t *testing.T) { + t.Run("replaces code schema", func(t *testing.T) { + result := preprocessString("abc123", "select * from [code].Table") + assert.Equal(t, "select * from [code@abc123].Table", result) + }) + + t.Run("case insensitive replacement", func(t *testing.T) { + result := 
preprocessString("abc123", "select * from [CODE].Table and [Code].Other") + assert.Contains(t, result, "[code@abc123].Table") + assert.Contains(t, result, "[code@abc123].Other") + }) + + t.Run("multiple occurrences", func(t *testing.T) { + sql := ` + select * from [code].A + join [code].B on A.id = B.id + where exists (select 1 from [code].C) + ` + result := preprocessString("abc123", sql) + assert.Equal(t, 3, strings.Count(result, "[code@abc123]")) + assert.NotContains(t, result, "[code].") + }) + + t.Run("no replacement needed", func(t *testing.T) { + sql := "select * from dbo.Table" + result := preprocessString("abc123", sql) + assert.Equal(t, sql, result) + }) +} + +func TestPreprocessorError(t *testing.T) { + t.Run("formats error message", func(t *testing.T) { + err := PreprocessorError{ + Pos: sqlparser.Pos{File: "test.sql", Line: 10, Col: 5}, + Message: "something went wrong", + } + + assert.Equal(t, "test.sql:10:5: something went wrong", err.Error()) + }) +} diff --git a/sqlcode.yaml b/sqlcode.yaml index 549c23f..8adefb7 100644 --- a/sqlcode.yaml +++ b/sqlcode.yaml @@ -1,6 +1,8 @@ databases: - localtest: - connection: sqlserver://localhost:1433?database=foo&user id=foouser&password=FooPasswd1 + mssql: + connection: sqlserver://mssql:1433?database=foo&user id=foouser&password=FooPasswd1 + pgsql: + connection: postgresql://sa:VippsPw1@postgres:5432/master?sslmode=disable # One option is to list other paths to include ('dependencies') here. 
diff --git a/sqlparser/batch.go b/sqlparser/batch.go new file mode 100644 index 0000000..25504f5 --- /dev/null +++ b/sqlparser/batch.go @@ -0,0 +1,91 @@ +package sqlparser + +import ( + "fmt" +) + +type Batch struct { + Nodes []Unparsed + DocString []PosString + CreateStatements int + TokenHandlers map[string]func(*Scanner, *Batch) bool + Errors []Error + BatchSeparatorToken TokenType +} + +func (n *Batch) Create(s *Scanner) { + n.Nodes = append(n.Nodes, CreateUnparsed(s)) +} + +func (n *Batch) HasErrors() bool { + return len(n.Errors) > 0 +} + +// Agnostic parser that handles comments, whitespace, and reserved words +func (n *Batch) Parse(s *Scanner) bool { + newLineEncounteredInDocstring := false + + for { + tt := s.TokenType() + switch tt { + case EOFToken: + return false + case WhitespaceToken, MultilineCommentToken: + n.Create(s) + // do not reset token for a single trailing newline + t := s.Token() + if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { + newLineEncounteredInDocstring = true + } else { + n.DocString = nil + } + s.NextToken() + case SinglelineCommentToken: + // We build up a list of single line comments for the "docstring"; + // it is reset whenever we encounter something else + n.DocString = append(n.DocString, PosString{s.Start(), s.Token()}) + n.Create(s) + newLineEncounteredInDocstring = false + s.NextToken() + case ReservedWordToken: + token := s.ReservedWord() + handler, exists := n.TokenHandlers[token] + if !exists { + n.Errors = append(n.Errors, Error{ + s.Start(), fmt.Sprintf("Expected , got: %s", token), + }) + s.NextToken() + } else { + if handler(s, n) { + // regardless of errors, go on and parse as far as we get... 
+ return true + } + } + case BatchSeparatorToken: + // TODO + errorEmitted := false + for { + switch s.NextToken() { + case WhitespaceToken: + continue + case MalformedBatchSeparatorToken: + if !errorEmitted { + n.Errors = append(n.Errors, Error{ + s.Start(), "`go` should be alone on a line without any comments", + }) + errorEmitted = true + } + continue + default: + return true + } + } + default: + n.Errors = append(n.Errors, Error{ + s.Start(), fmt.Sprintf("Unexpected token: %s", s.Token()), + }) + s.NextToken() + n.DocString = nil + } + } +} diff --git a/sqlparser/create.go b/sqlparser/create.go new file mode 100644 index 0000000..0bdbbf2 --- /dev/null +++ b/sqlparser/create.go @@ -0,0 +1,98 @@ +package sqlparser + +import ( + "database/sql/driver" + "io" + "strings" + + "gopkg.in/yaml.v3" +) + +type Create struct { + CreateType string // "procedure", "function" or "type" + QuotedName PosString // proc/func/type name, including [] + Body []Unparsed + DependsOn []PosString + Docstring []PosString // comment lines before the create statement. Note: this is also part of Body + Driver driver.Driver // the sql driver this document is intended for +} + +func (c Create) DocstringAsString() string { + var result []string + for _, line := range c.Docstring { + result = append(result, line.Value) + } + return strings.Join(result, "\n") +} + +func (c Create) DocstringYamldoc() (string, error) { + var yamldoc []string + parsing := false + for _, line := range c.Docstring { + if strings.HasPrefix(line.Value, "--!") { + parsing = true + if !strings.HasPrefix(line.Value, "--! 
") { + return "", Error{line.Pos, "YAML document in docstring; missing space after `--!`"} + } + yamldoc = append(yamldoc, line.Value[4:]) + } else if parsing { + return "", Error{line.Pos, "once embedded yaml document is started (lines prefixed with `--!`), it must continue until create statement"} + } + } + return strings.Join(yamldoc, "\n"), nil +} + +func (c Create) ParseYamlInDocstring(out any) error { + yamldoc, err := c.DocstringYamldoc() + if err != nil { + return err + } + return yaml.Unmarshal([]byte(yamldoc), out) +} + +func (c Create) Serialize(w io.StringWriter) error { + for _, l := range c.Body { + if _, err := w.WriteString(l.RawValue); err != nil { + return err + } + } + return nil +} + +func (c Create) SerializeBytes(w io.Writer) error { + for _, l := range c.Body { + if _, err := w.Write([]byte(l.RawValue)); err != nil { + return err + } + } + return nil +} + +func (c Create) String() string { + var buf strings.Builder + err := c.Serialize(&buf) + if err != nil { + panic(err) + } + return buf.String() +} + +func (c Create) WithoutPos() Create { + var body []Unparsed + for _, x := range c.Body { + body = append(body, x.WithoutPos()) + } + return Create{ + CreateType: c.CreateType, + QuotedName: c.QuotedName, + DependsOn: c.DependsOn, + Body: body, + } +} + +func (c Create) DependsOnStrings() (result []string) { + for _, x := range c.DependsOn { + result = append(result, x.Value) + } + return +} diff --git a/sqlparser/document.go b/sqlparser/document.go new file mode 100644 index 0000000..1721875 --- /dev/null +++ b/sqlparser/document.go @@ -0,0 +1,135 @@ +package sqlparser + +import ( + "fmt" + "path/filepath" + "slices" + "strings" +) + +// Document represents a parsed SQL document, containing +// declarations, create statements, pragmas, and errors. 
+// It provides methods to access and manipulate these components +// for T-SQL and PostgreSQL +type Document interface { + Empty() bool + HasErrors() bool + + Creates() []Create + Declares() []Declare + Errors() []Error + PragmaIncludeIf() []string + Include(other Document) + Sort() + Parse(s *Scanner) error + WithoutPos() Document +} + +// Helper function to parse a SQL document from a string input +func ParseString(filename FileRef, input string) (result Document) { + result = NewDocumentFromExtension(filepath.Ext(strings.ToLower(string(filename)))) + Parse(&Scanner{input: input, file: filename}, result) + return +} + +// Based on the input file extension, create the appropriate Document type +func NewDocumentFromExtension(extension string) Document { + switch extension { + case ".sql": + return &TSqlDocument{} + case ".pgsql": + return &PGSqlDocument{} + default: + panic("unhandled document type: " + extension) + } +} + +// parseCodeschemaName parses `[code] . something`, and returns `something` +// in quoted form (`[something]`). Also copy to `target`. Empty string on error. +// Note: To follow conventions, consume one extra token at the end even if we know +// it fill not be consumed by this function... 
+func ParseCodeschemaName(s *Scanner, target *[]Unparsed, statementTokens []string) (PosString, error) { + CopyToken(s, target) + NextTokenCopyingWhitespace(s, target) + if s.TokenType() != DotToken { + RecoverToNextStatementCopying(s, target, statementTokens) + return PosString{Value: ""}, fmt.Errorf("[code] must be followed by '.'") + } + + CopyToken(s, target) + + NextTokenCopyingWhitespace(s, target) + switch s.TokenType() { + case UnquotedIdentifierToken: + // To get something uniform for comparison, quote all names + CopyToken(s, target) + result := PosString{Pos: s.Start(), Value: "[" + s.Token() + "]"} + NextTokenCopyingWhitespace(s, target) + return result, nil + case QuotedIdentifierToken: + CopyToken(s, target) + result := PosString{Pos: s.Start(), Value: s.Token()} + NextTokenCopyingWhitespace(s, target) + return result, nil + default: + RecoverToNextStatementCopying(s, target, statementTokens) + return PosString{Value: ""}, fmt.Errorf("[code]. must be followed an identifier") + } +} + +// NextTokenCopyingWhitespace is like s.NextToken(), but if whitespace is encountered +// it is simply copied into `target`. Upon return, the scanner is located at a non-whitespace +// token, and target is either unmodified or filled with some whitespace nodes. +func NextTokenCopyingWhitespace(s *Scanner, target *[]Unparsed) { + for { + tt := s.NextToken() + switch tt { + case EOFToken, BatchSeparatorToken: + // do not copy + return + case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: + // copy, and loop around + CopyToken(s, target) + continue + default: + return + } + } + +} + +func RecoverToNextStatementCopying(s *Scanner, target *[]Unparsed, StatementTokens []string) { + // We hit an unexpected token ... 
as an heuristic for continuing parsing, + // skip parsing until we hit a reserved word that starts a statement + // we recognize + for { + NextTokenCopyingWhitespace(s, target) + switch s.TokenType() { + case ReservedWordToken: + if slices.Contains(StatementTokens, s.ReservedWord()) { + return + } + case EOFToken: + return + default: + CopyToken(s, target) + } + } +} + +func RecoverToNextStatement(s *Scanner, StatementTokens []string) { + // We hit an unexpected token ... as an heuristic for continuing parsing, + // skip parsing until we hit a reserved word that starts a statement + // we recognize + for { + s.NextNonWhitespaceCommentToken() + switch s.TokenType() { + case ReservedWordToken: + if slices.Contains(StatementTokens, s.ReservedWord()) { + return + } + case EOFToken: + return + } + } +} diff --git a/sqlparser/document_test.go b/sqlparser/document_test.go new file mode 100644 index 0000000..ec011c5 --- /dev/null +++ b/sqlparser/document_test.go @@ -0,0 +1,151 @@ +package sqlparser + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewDocumentFromExtension(t *testing.T) { + t.Run("returns TSqlDocument for .sql extension", func(t *testing.T) { + doc := NewDocumentFromExtension(".sql") + + _, ok := doc.(*TSqlDocument) + assert.True(t, ok, "Expected TSqlDocument type") + assert.NotNil(t, doc) + }) + + t.Run("returns PGSqlDocument for .pgsql extension", func(t *testing.T) { + doc := NewDocumentFromExtension(".pgsql") + + _, ok := doc.(*PGSqlDocument) + assert.True(t, ok, "Expected PGSqlDocument type") + assert.NotNil(t, doc) + }) + + t.Run("panics for unsupported extension", func(t *testing.T) { + assert.Panics(t, func() { + NewDocumentFromExtension(".txt") + }, "Expected panic for unsupported extension") + }) + + t.Run("panics for empty extension", func(t *testing.T) { + assert.Panics(t, func() { + NewDocumentFromExtension("") + }, "Expected panic for empty extension") + }) + + 
t.Run("panics for unknown SQL extension", func(t *testing.T) { + assert.Panics(t, func() { + NewDocumentFromExtension(".mysql") + }, "Expected panic for .mysql extension") + }) + + t.Run("extension matching is case insensitive", func(t *testing.T) { + assert.Panics(t, func() { + NewDocumentFromExtension(".SQL") + }, "Expected panic for uppercase .SQL") + }) + + t.Run("returned documents implement Document interface", func(t *testing.T) { + sqlDoc := NewDocumentFromExtension(".sql") + pgsqlDoc := NewDocumentFromExtension(".pgsql") + require.NotEqual(t, sqlDoc, pgsqlDoc) + }) +} + +func TestDocument_parseCodeschemaName(t *testing.T) { + t.Run("parses unquoted identifier", func(t *testing.T) { + s := NewScanner("test.sql", "[code].TestProc") + s.NextToken() + var target []Unparsed + + result, err := ParseCodeschemaName(s, &target, nil) + assert.NoError(t, err) + assert.Equal(t, "[TestProc]", result.Value) + assert.NotEmpty(t, target) + }) + + t.Run("parses quoted identifier", func(t *testing.T) { + s := NewScanner("test.sql", "[code].[Test Proc]") + s.NextToken() + var target []Unparsed + + result, err := ParseCodeschemaName(s, &target, nil) + assert.NoError(t, err) + + assert.Equal(t, "[Test Proc]", result.Value) + }) + + t.Run("errors on missing dot", func(t *testing.T) { + s := NewScanner("test.sql", "[code] TestProc") + s.NextToken() + var target []Unparsed + + result, err := ParseCodeschemaName(s, &target, nil) + assert.Error(t, err) + + assert.Equal(t, "", result.Value) + assert.ErrorContains(t, err, "must be followed by '.'") + }) + + t.Run("errors on missing identifier", func(t *testing.T) { + s := NewScanner("test.sql", "[code].123") + s.NextToken() + var target []Unparsed + + result, err := ParseCodeschemaName(s, &target, nil) + + assert.Error(t, err) + assert.Equal(t, "", result.Value) + assert.ErrorContains(t, err, "must be followed an identifier") + }) +} + +func TestDocument_recoverToNextStatement(t *testing.T) { + t.Run("recovers to declare", func(t 
*testing.T) { + s := NewScanner("test.sql", "invalid tokens here declare @x int = 1") + s.NextToken() + + RecoverToNextStatement(s, []string{"declare"}) + + fmt.Printf("%#v\n", s) + + assert.Equal(t, ReservedWordToken, s.TokenType()) + assert.Equal(t, "declare", s.ReservedWord()) + }) + + t.Run("recovers to create", func(t *testing.T) { + s := NewScanner("test.sql", "bad stuff create procedure") + s.NextToken() + + RecoverToNextStatement(s, []string{"create"}) + + assert.Equal(t, ReservedWordToken, s.TokenType()) + assert.Equal(t, "create", s.ReservedWord()) + }) + + t.Run("stops at EOF", func(t *testing.T) { + s := NewScanner("test.sql", "no keywords") + s.NextToken() + + RecoverToNextStatement(s, []string{}) + + assert.Equal(t, EOFToken, s.TokenType()) + }) +} + +func TestDocument_recoverToNextStatementCopying(t *testing.T) { + t.Run("copies tokens while recovering", func(t *testing.T) { + s := NewScanner("test.sql", "bad token declare") + s.NextToken() + var target []Unparsed + + RecoverToNextStatementCopying(s, &target, []string{"declare"}) + + assert.NotEmpty(t, target) + assert.Equal(t, "declare", s.ReservedWord()) + }) +} diff --git a/sqlparser/dom.go b/sqlparser/dom.go index cc661f4..a75db38 100644 --- a/sqlparser/dom.go +++ b/sqlparser/dom.go @@ -2,26 +2,9 @@ package sqlparser import ( "fmt" - "gopkg.in/yaml.v3" - "io" "strings" ) -type Unparsed struct { - Type TokenType - Start, Stop Pos - RawValue string -} - -func (u Unparsed) WithoutPos() Unparsed { - return Unparsed{ - Type: u.Type, - Start: Pos{}, - Stop: Pos{}, - RawValue: u.RawValue, - } -} - type Declare struct { Start Pos Stop Pos @@ -59,47 +42,6 @@ func (p PosString) String() string { return p.Value } -type Create struct { - CreateType string // "procedure", "function" or "type" - QuotedName PosString // proc/func/type name, including [] - Body []Unparsed - DependsOn []PosString - Docstring []PosString // comment lines before the create statement. 
Note: this is also part of Body -} - -func (c Create) DocstringAsString() string { - var result []string - for _, line := range c.Docstring { - result = append(result, line.Value) - } - return strings.Join(result, "\n") -} - -func (c Create) DocstringYamldoc() (string, error) { - var yamldoc []string - parsing := false - for _, line := range c.Docstring { - if strings.HasPrefix(line.Value, "--!") { - parsing = true - if !strings.HasPrefix(line.Value, "--! ") { - return "", Error{line.Pos, "YAML document in docstring; missing space after `--!`"} - } - yamldoc = append(yamldoc, line.Value[4:]) - } else if parsing { - return "", Error{line.Pos, "once embedded yaml document is started (lines prefixed with `--!`), it must continue until create statement"} - } - } - return strings.Join(yamldoc, "\n"), nil -} - -func (c Create) ParseYamlInDocstring(out any) error { - yamldoc, err := c.DocstringYamldoc() - if err != nil { - return err - } - return yaml.Unmarshal([]byte(yamldoc), out) -} - type Type struct { BaseType string Args []string @@ -125,115 +67,3 @@ func (e Error) Error() string { func (e Error) WithoutPos() Error { return Error{Message: e.Message} } - -type Document struct { - PragmaIncludeIf []string - Creates []Create - Declares []Declare - Errors []Error -} - -func (c Create) Serialize(w io.StringWriter) error { - for _, l := range c.Body { - if _, err := w.WriteString(l.RawValue); err != nil { - return err - } - } - return nil -} - -func (c Create) SerializeBytes(w io.Writer) error { - for _, l := range c.Body { - if _, err := w.Write([]byte(l.RawValue)); err != nil { - return err - } - } - return nil -} - -func (c Create) String() string { - var buf strings.Builder - err := c.Serialize(&buf) - if err != nil { - panic(err) - } - return buf.String() -} - -func (c Create) WithoutPos() Create { - var body []Unparsed - for _, x := range c.Body { - body = append(body, x.WithoutPos()) - } - return Create{ - CreateType: c.CreateType, - QuotedName: c.QuotedName, - 
DependsOn: c.DependsOn, - Body: body, - } -} - -func (c Create) DependsOnStrings() (result []string) { - for _, x := range c.DependsOn { - result = append(result, x.Value) - } - return -} - -// Transform a Document to remove all Position information; this is used -// to 'unclutter' a DOM to more easily write assertions on it. -func (d Document) WithoutPos() Document { - var cs []Create - for _, x := range d.Creates { - cs = append(cs, x.WithoutPos()) - } - var ds []Declare - for _, x := range d.Declares { - ds = append(ds, x.WithoutPos()) - } - var es []Error - for _, x := range d.Errors { - es = append(es, x.WithoutPos()) - } - return Document{ - Creates: cs, - Declares: ds, - Errors: es, - } -} - -func (d *Document) Include(other Document) { - // Do not copy PragmaIncludeIf, since that is local to a single file. - // Its contents is also present in each Create. - d.Declares = append(d.Declares, other.Declares...) - d.Creates = append(d.Creates, other.Creates...) - d.Errors = append(d.Errors, other.Errors...) -} - -func (d *Document) parseSinglePragma(s *Scanner) { - pragma := strings.TrimSpace(strings.TrimPrefix(s.Token(), "--sqlcode:")) - if pragma == "" { - return - } - parts := strings.Split(pragma, " ") - if len(parts) != 2 { - d.addError(s, "Illegal pragma: "+s.Token()) - return - } - if parts[0] != "include-if" { - d.addError(s, "Illegal pragma: "+s.Token()) - return - } - d.PragmaIncludeIf = append(d.PragmaIncludeIf, strings.Split(parts[1], ",")...) 
-} - -func (d *Document) parsePragmas(s *Scanner) { - for s.TokenType() == PragmaToken { - d.parseSinglePragma(s) - s.NextNonWhitespaceToken() - } -} - -func (d Document) Empty() bool { - return len(d.Creates) > 0 || len(d.Declares) > 0 -} diff --git a/sqlparser/node_test.go b/sqlparser/node_test.go new file mode 100644 index 0000000..04173c6 --- /dev/null +++ b/sqlparser/node_test.go @@ -0,0 +1 @@ +package sqlparser diff --git a/sqlparser/parser.go b/sqlparser/parser.go index 40eebe9..c15e25f 100644 --- a/sqlparser/parser.go +++ b/sqlparser/parser.go @@ -9,38 +9,20 @@ import ( "errors" "fmt" "io/fs" + "path/filepath" "regexp" - "sort" + "slices" "strings" ) var templateRoutineName string = "\ndeclare @RoutineName nvarchar(128)\nset @RoutineName = '%s'\n" +var supportedSqlExtensions []string = []string{".sql", ".pgsql"} + func CopyToken(s *Scanner, target *[]Unparsed) { *target = append(*target, CreateUnparsed(s)) } -// NextTokenCopyingWhitespace is like s.NextToken(), but if whitespace is encountered -// it is simply copied into `target`. Upon return, the scanner is located at a non-whitespace -// token, and target is either unmodified or filled with some whitespace nodes. 
-func NextTokenCopyingWhitespace(s *Scanner, target *[]Unparsed) { - for { - tt := s.NextToken() - switch tt { - case EOFToken, BatchSeparatorToken: - // do not copy - return - case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: - // copy, and loop around - CopyToken(s, target) - continue - default: - return - } - } - -} - // AdvanceAndCopy is like NextToken; advance to next token that is not whitespace and return // Note: The 'go' and EOF tokens are *not* copied func AdvanceAndCopy(s *Scanner, target *[]Unparsed) { @@ -62,466 +44,7 @@ func AdvanceAndCopy(s *Scanner, target *[]Unparsed) { } } -func CreateUnparsed(s *Scanner) Unparsed { - return Unparsed{ - Type: s.TokenType(), - Start: s.Start(), - Stop: s.Stop(), - RawValue: s.Token(), - } -} - -func (d *Document) addError(s *Scanner, msg string) { - d.Errors = append(d.Errors, Error{ - Pos: s.Start(), - Message: msg, - }) -} - -func (d *Document) unexpectedTokenError(s *Scanner) { - d.addError(s, "Unexpected: "+s.Token()) -} - -func (doc *Document) parseTypeExpression(s *Scanner) (t Type) { - parseArgs := func() { - // parses *after* the initial (; consumes trailing ) - for { - switch { - case s.TokenType() == NumberToken: - t.Args = append(t.Args, s.Token()) - case s.TokenType() == UnquotedIdentifierToken && s.TokenLower() == "max": - t.Args = append(t.Args, "max") - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - s.NextNonWhitespaceCommentToken() - switch { - case s.TokenType() == CommaToken: - s.NextNonWhitespaceCommentToken() - continue - case s.TokenType() == RightParenToken: - s.NextNonWhitespaceCommentToken() - return - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - } - } - - if s.TokenType() != UnquotedIdentifierToken { - panic("assertion failed, bug in caller") - } - t.BaseType = s.Token() - s.NextNonWhitespaceCommentToken() - if s.TokenType() == LeftParenToken { - s.NextNonWhitespaceCommentToken() - parseArgs() 
- } - return -} - -func (doc *Document) parseDeclare(s *Scanner) (result []Declare) { - declareStart := s.Start() - // parse what is *after* the `declare` reserved keyword -loop: - for { - if s.TokenType() != VariableIdentifierToken { - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - - variableName := s.Token() - if !strings.HasPrefix(strings.ToLower(variableName), "@enum") && - !strings.HasPrefix(strings.ToLower(variableName), "@global") && - !strings.HasPrefix(strings.ToLower(variableName), "@const") { - doc.addError(s, "sqlcode constants needs to have names starting with @Enum, @Global or @Const: "+variableName) - } - s.NextNonWhitespaceCommentToken() - var variableType Type - switch s.TokenType() { - case EqualToken: - doc.addError(s, "sqlcode constants needs a type declared explicitly") - s.NextNonWhitespaceCommentToken() - case UnquotedIdentifierToken: - variableType = doc.parseTypeExpression(s) - } - - if s.TokenType() != EqualToken { - doc.addError(s, "sqlcode constants needs to be assigned at once using =") - doc.recoverToNextStatement(s) - } - - switch s.NextNonWhitespaceCommentToken() { - case NumberToken, NVarcharLiteralToken, VarcharLiteralToken: - result = append(result, Declare{ - Start: declareStart, - Stop: s.Stop(), - VariableName: variableName, - Datatype: variableType, - Literal: CreateUnparsed(s), - }) - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - - switch s.NextNonWhitespaceCommentToken() { - case CommaToken: - s.NextNonWhitespaceCommentToken() - continue - case SemicolonToken: - s.NextNonWhitespaceCommentToken() - break loop - default: - break loop - } - } - if len(result) == 0 { - doc.addError(s, "incorrect syntax; no variables successfully declared") - } - return -} - -func (doc *Document) parseBatchSeparator(s *Scanner) { - // just saw a 'go'; just make sure there's nothing bad trailing it - // (if there is, convert to errors and move on until the line is consumed - 
errorEmitted := false - for { - switch s.NextToken() { - case WhitespaceToken: - continue - case MalformedBatchSeparatorToken: - if !errorEmitted { - doc.addError(s, "`go` should be alone on a line without any comments") - errorEmitted = true - } - continue - default: - return - } - } -} - -func (doc *Document) parseDeclareBatch(s *Scanner) (hasMore bool) { - if s.ReservedWord() != "declare" { - panic("assertion failed, incorrect use in caller") - } - for { - tt := s.TokenType() - switch { - case tt == EOFToken: - return false - case tt == ReservedWordToken && s.ReservedWord() == "declare": - s.NextNonWhitespaceCommentToken() - d := doc.parseDeclare(s) - doc.Declares = append(doc.Declares, d...) - case tt == ReservedWordToken && s.ReservedWord() != "declare": - doc.addError(s, "Only 'declare' allowed in this batch") - doc.recoverToNextStatement(s) - case tt == BatchSeparatorToken: - doc.parseBatchSeparator(s) - return true - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - } - } -} - -func (doc *Document) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { - var nodes []Unparsed - var docstring []PosString - newLineEncounteredInDocstring := false - - var createCountInBatch int - - for { - tt := s.TokenType() - switch tt { - case EOFToken: - return false - case WhitespaceToken, MultilineCommentToken: - nodes = append(nodes, CreateUnparsed(s)) - // do not reset token for a single trailing newline - t := s.Token() - if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { - newLineEncounteredInDocstring = true - } else { - docstring = nil - } - s.NextToken() - case SinglelineCommentToken: - // We build up a list of single line comments for the "docstring"; - // it is reset whenever we encounter something else - docstring = append(docstring, PosString{s.Start(), s.Token()}) - nodes = append(nodes, CreateUnparsed(s)) - newLineEncounteredInDocstring = false - s.NextToken() - case ReservedWordToken: - switch s.ReservedWord() { - case 
"declare": - // First declare-statement; enter a mode where we assume all contents - // of batch are declare statements - if !isFirst { - doc.addError(s, "'declare' statement only allowed in first batch") - } - // regardless of errors, go on and parse as far as we get... - return doc.parseDeclareBatch(s) - case "create": - // should be start of create procedure or create function... - c := doc.parseCreate(s, createCountInBatch) - // *prepend* what we saw before getting to the 'create' - createCountInBatch++ - c.Body = append(nodes, c.Body...) - c.Docstring = docstring - doc.Creates = append(doc.Creates, c) - default: - doc.addError(s, "Expected 'declare' or 'create', got: "+s.ReservedWord()) - s.NextToken() - } - case BatchSeparatorToken: - doc.parseBatchSeparator(s) - return true - default: - doc.unexpectedTokenError(s) - s.NextToken() - docstring = nil - } - } -} - -func (d *Document) recoverToNextStatementCopying(s *Scanner, target *[]Unparsed) { - // We hit an unexpected token ... as an heuristic for continuing parsing, - // skip parsing until we hit a reserved word that starts a statement - // we recognize - for { - NextTokenCopyingWhitespace(s, target) - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "declare", "create", "go": - return - } - case EOFToken: - return - default: - CopyToken(s, target) - } - } -} - -func (d *Document) recoverToNextStatement(s *Scanner) { - // We hit an unexpected token ... as an heuristic for continuing parsing, - // skip parsing until we hit a reserved word that starts a statement - // we recognize - for { - s.NextNonWhitespaceCommentToken() - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "declare", "create", "go": - return - } - case EOFToken: - return - } - } -} - -// parseCodeschemaName parses `[code] . something`, and returns `something` -// in quoted form (`[something]`). Also copy to `target`. Empty string on error. 
-// Note: To follow conventions, consume one extra token at the end even if we know -// it fill not be consumed by this function... -func (d *Document) parseCodeschemaName(s *Scanner, target *[]Unparsed) PosString { - CopyToken(s, target) - NextTokenCopyingWhitespace(s, target) - if s.TokenType() != DotToken { - d.addError(s, fmt.Sprintf("[code] must be followed by '.'")) - d.recoverToNextStatementCopying(s, target) - return PosString{Value: ""} - } - CopyToken(s, target) - - NextTokenCopyingWhitespace(s, target) - switch s.TokenType() { - case UnquotedIdentifierToken: - // To get something uniform for comparison, quote all names - CopyToken(s, target) - result := PosString{Pos: s.Start(), Value: "[" + s.Token() + "]"} - NextTokenCopyingWhitespace(s, target) - return result - case QuotedIdentifierToken: - CopyToken(s, target) - result := PosString{Pos: s.Start(), Value: s.Token()} - NextTokenCopyingWhitespace(s, target) - return result - default: - d.addError(s, fmt.Sprintf("[code]. must be followed an identifier")) - d.recoverToNextStatementCopying(s, target) - return PosString{Value: ""} - } -} - -// parseCreate parses anything that starts with "create". Position is -// *on* the create token. -// At this stage in sqlcode parser development we're only interested -// in procedures/functions/types as opaque blocks of SQL code where -// we only track dependencies between them and their declared name; -// so we treat them with the same code. We consume until the end of -// the batch; only one declaration allowed per batch. Everything -// parsed here will also be added to `batch`. On any error, copying -// to batch stops / becomes erratic.. 
-func (d *Document) parseCreate(s *Scanner, createCountInBatch int) (result Create) { - if s.ReservedWord() != "create" { - panic("illegal use by caller") - } - CopyToken(s, &result.Body) - - NextTokenCopyingWhitespace(s, &result.Body) - - createType := strings.ToLower(s.Token()) - if !(createType == "procedure" || createType == "function" || createType == "type") { - d.addError(s, fmt.Sprintf("sqlcode only supports creating procedures, functions or types; not `%s`", createType)) - d.recoverToNextStatementCopying(s, &result.Body) - return - } - if (createType == "procedure" || createType == "function") && createCountInBatch > 0 { - d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") - d.recoverToNextStatementCopying(s, &result.Body) - return - } - - result.CreateType = createType - CopyToken(s, &result.Body) - - NextTokenCopyingWhitespace(s, &result.Body) - - // Insist on [code]. - if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" { - d.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType)) - d.recoverToNextStatementCopying(s, &result.Body) - return - } - result.QuotedName = d.parseCodeschemaName(s, &result.Body) - if result.QuotedName.String() == "" { - return - } - - // We have matched "create [code]."; at this - // point we copy the rest until the batch ends; *but* track dependencies - // + some other details mentioned below - - //firstAs := true // See comment below on rowcount - -tailloop: - for { - tt := s.TokenType() - switch { - case tt == ReservedWordToken && s.ReservedWord() == "create": - // So, we're currently parsing 'create ...' and we see another 'create'. 
- // We split in two cases depending on the context we are currently in - // (createType is referring to how we entered this function, *NOT* the - // `create` statement we are looking at now - switch createType { // note: this is the *outer* create type, not the one of current scanner position - case "function", "procedure": - // Within a function/procedure we can allow 'create index', 'create table' and nothing - // else. (Well, only procedures can have them, but we'll leave it to T-SQL to complain - // about that aspect, not relevant for batch / dependency parsing) - // - // What is important is a function/procedure/type isn't started on without a 'go' - // in between; so we block those 3 from appearing in the same batch - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) - tt2 := s.TokenType() - - if (tt2 == ReservedWordToken && (s.ReservedWord() == "function" || s.ReservedWord() == "procedure")) || - (tt2 == UnquotedIdentifierToken && s.TokenLower() == "type") { - d.recoverToNextStatementCopying(s, &result.Body) - d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") - return - } - case "type": - // We allow more than one type creation in a batch; and 'create' can never appear - // scoped within 'create type'. 
So at a new create we are done with the previous - // one, and return it -- the caller can then re-enter this function from the top - break tailloop - default: - panic("assertion failed") - } - - case tt == EOFToken || tt == BatchSeparatorToken: - break tailloop - case tt == QuotedIdentifierToken && s.Token() == "[code]": - // Parse a dependency - dep := d.parseCodeschemaName(s, &result.Body) - found := false - for _, existing := range result.DependsOn { - if existing.Value == dep.Value { - found = true - break - } - } - if !found { - result.DependsOn = append(result.DependsOn, dep) - } - case tt == ReservedWordToken && s.Token() == "as": - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) - /* - TODO: Fix and re-enable - This code add RoutineName for convenience. So: - - create procedure [code@5420c0269aaf].Test as - begin - select 1 - end - go - - becomes: - - create procedure [code@5420c0269aaf].Test as - declare @RoutineName nvarchar(128) - set @RoutineName = 'Test' - begin - select 1 - end - go - - However, for some very strange reason, @@rowcount is 1 with the first version, - and it is 2 with the second version. 
- if firstAs { - // Add the `RoutineName` token as a convenience, so that we can refer to the procedure/function name - // from inside the procedure (for example, when logging) - if result.CreateType == "procedure" { - procNameToken := Unparsed{ - Type: OtherToken, - RawValue: fmt.Sprintf(templateRoutineName, strings.Trim(result.QuotedName.Value, "[]")), - } - result.Body = append(result.Body, procNameToken) - } - firstAs = false - } - */ - - default: - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) - } - } - - sort.Slice(result.DependsOn, func(i, j int) bool { - return result.DependsOn[i].Value < result.DependsOn[j].Value - }) - return -} - -func Parse(s *Scanner, result *Document) { +func Parse(s *Scanner, result Document) { // Top-level parse; this focuses on splitting into "batches" separated // by 'go'. @@ -538,24 +61,16 @@ func Parse(s *Scanner, result *Document) { // // `s` will typically never be positioned on whitespace except in // whitespace-preserving parsing - s.NextNonWhitespaceToken() - result.parsePragmas(s) - hasMore := result.parseBatch(s, true) - for hasMore { - hasMore = result.parseBatch(s, false) + err := result.Parse(s) + if err != nil { + panic(fmt.Sprintf("failed to parse document: %s: %e", s.file, err)) } return } -func ParseString(filename FileRef, input string) (result Document) { - Parse(&Scanner{input: input, file: filename}, &result) - return -} - -// ParseFileystems iterates through a list of filesystems and parses all files -// matching `*.sql`, determines which one are sqlcode files from the contents, -// and returns the combination of all of them. +// ParseFileystems iterates through a list of filesystems and parses all supported +// SQL files and returns the combination of all of them. // // err will only return errors related to filesystems/reading. Errors // related to parsing/sorting will be in result.Errors. 
@@ -580,7 +95,9 @@ func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string, if strings.HasPrefix(path, ".") || strings.Contains(path, "/.") { return nil } - if !strings.HasSuffix(path, ".sql") { + + extension := filepath.Ext(path) + if !slices.Contains(supportedSqlExtensions, extension) { return nil } @@ -605,10 +122,10 @@ func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string, } hashes[hash] = pathDesc - var fdoc Document - Parse(&Scanner{input: string(buf), file: FileRef(path)}, &fdoc) + fdoc := NewDocumentFromExtension(extension) + Parse(&Scanner{input: string(buf), file: FileRef(path)}, fdoc) - if matchesIncludeTags(fdoc.PragmaIncludeIf, includeTags) { + if matchesIncludeTags(fdoc.PragmaIncludeIf(), includeTags) { filenames = append(filenames, pathDesc) result.Include(fdoc) } @@ -620,17 +137,7 @@ func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string, } } - // Do the topological sort; and include any error with it as part - // of `result`, *not* return it as err - sortedCreates, errpos, sortErr := TopologicalSort(result.Creates) - if sortErr != nil { - result.Errors = append(result.Errors, Error{ - Pos: errpos, - Message: sortErr.Error(), - }) - } else { - result.Creates = sortedCreates - } + result.Sort() return } diff --git a/sqlparser/parser_test.go b/sqlparser/parser_test.go index 7bd20b8..39f1e58 100644 --- a/sqlparser/parser_test.go +++ b/sqlparser/parser_test.go @@ -2,13 +2,32 @@ package sqlparser import ( "fmt" + "io/fs" "strings" "testing" + "testing/fstest" + "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestPostgresqlCreate(t *testing.T) { + doc := ParseString("test.pgsql", ` +create procedure [code].test() +language plpgsql +as $$ +begin + perform 1; +end; +$$; + `) + + require.Len(t, doc.Creates(), 1) + require.Equal(t, &stdlib.Driver{}, doc.Creates()[0].Driver) +} + 
func TestParserSmokeTest(t *testing.T) { doc := ParseString("test.sql", ` /* test is a test @@ -41,8 +60,9 @@ end; `) docNoPos := doc.WithoutPos() - require.Equal(t, 1, len(doc.Creates)) - c := doc.Creates[0] + require.Equal(t, 1, len(doc.Creates())) + c := doc.Creates()[0] + require.Equal(t, &mssql.Driver{}, c.Driver) assert.Equal(t, "[TestFunc]", c.QuotedName.Value) assert.Equal(t, []string{"[HelloFunc]", "[OtherFunc]"}, c.DependsOnStrings()) @@ -62,7 +82,7 @@ end; { Message: "'declare' statement only allowed in first batch", }, - }, docNoPos.Errors) + }, docNoPos.Errors()) assert.Equal(t, []Declare{ @@ -110,7 +130,7 @@ end; }, }, }, - docNoPos.Declares, + docNoPos.Declares(), ) // repr.Println(doc) } @@ -130,21 +150,21 @@ create function [code].Two(); { Message: "a procedure/function must be alone in a batch; use 'go' to split batches", }, - }, doc.Errors) + }, doc.Errors()) } func TestBuggyDeclare(t *testing.T) { // this caused parses to infinitely loop; regression test... doc := ParseString("test.sql", `declare @EnumA int = 4 @EnumB tinyint = 5 @ENUM_C bigint = 435;`) - assert.Equal(t, 1, len(doc.Errors)) - assert.Equal(t, "Unexpected: @EnumB", doc.Errors[0].Message) + assert.Equal(t, 1, len(doc.Errors())) + assert.Equal(t, "Unexpected: @EnumB", doc.Errors()[0].Message) } func TestCreateType(t *testing.T) { doc := ParseString("test.sql", `create type [code].MyType as table (x int not null primary key);`) - assert.Equal(t, 1, len(doc.Creates)) - assert.Equal(t, "type", doc.Creates[0].CreateType) - assert.Equal(t, "[MyType]", doc.Creates[0].QuotedName.Value) + assert.Equal(t, 1, len(doc.Creates())) + assert.Equal(t, "type", doc.Creates()[0].CreateType) + assert.Equal(t, "[MyType]", doc.Creates()[0].QuotedName.Value) } func TestPragma(t *testing.T) { @@ -159,7 +179,7 @@ create procedure [code].ProcedureShouldAlsoHavePragmasAnnotated() func TestInfiniteLoopRegression(t *testing.T) { // success if we terminate!... 
doc := ParseString("test.sql", `@declare`) - assert.Equal(t, 1, len(doc.Errors)) + assert.Equal(t, 1, len(doc.Errors())) } func TestDeclareSeparation(t *testing.T) { @@ -170,7 +190,7 @@ func TestDeclareSeparation(t *testing.T) { doc := ParseString("test.sql", ` declare @EnumFirst int = 3, @EnumSecond varchar(max) = 'hello'declare @EnumThird int=3 declare @EnumFourth int=4;declare @EnumFifth int =5 `) - //repr.Println(doc.Declares) + //repr.Println(doc.Declares()) require.Equal(t, []Declare{ { VariableName: "@EnumFirst", @@ -197,7 +217,7 @@ declare @EnumFirst int = 3, @EnumSecond varchar(max) = 'hello'declare @EnumThird Datatype: Type{BaseType: "int"}, Literal: Unparsed{Type: NumberToken, RawValue: "5"}, }, - }, doc.WithoutPos().Declares) + }, doc.WithoutPos().Declares()) } func TestBatchDivisionsAndCreateStatements(t *testing.T) { @@ -212,7 +232,7 @@ go create type [code].Batch3 as table (x int); `) commentCount := 0 - for _, c := range doc.Creates { + for _, c := range doc.Creates() { for _, b := range c.Body { if strings.Contains(b.RawValue, "2nd") { commentCount++ @@ -231,13 +251,13 @@ create type [code].Type1 as table (x int); create type [code].Type2 as table (x int); create type [code].Type3 as table (x int); `) - require.Equal(t, 3, len(doc.Creates)) - assert.Equal(t, "[Type1]", doc.Creates[0].QuotedName.Value) - assert.Equal(t, "[Type3]", doc.Creates[2].QuotedName.Value) + require.Equal(t, 3, len(doc.Creates())) + assert.Equal(t, "[Type1]", doc.Creates()[0].QuotedName.Value) + assert.Equal(t, "[Type3]", doc.Creates()[2].QuotedName.Value) // There was a bug that the last item in the body would be the 'create' // of the next statement; regression test.. 
- assert.Equal(t, "\n", doc.Creates[0].Body[len(doc.Creates[0].Body)-1].RawValue) - assert.Equal(t, "create", doc.Creates[1].Body[0].RawValue) + assert.Equal(t, "\n", doc.Creates()[0].Body[len(doc.Creates()[0].Body)-1].RawValue) + assert.Equal(t, "create", doc.Creates()[1].Body[0].RawValue) } func TestCreateProcs(t *testing.T) { @@ -250,10 +270,10 @@ create type [code].MyType () create procedure [code].MyProcedure () `) // First function and last procedure triggers errors. - require.Equal(t, 2, len(doc.Errors)) + require.Equal(t, 2, len(doc.Errors())) emsg := "a procedure/function must be alone in a batch; use 'go' to split batches" - assert.Equal(t, emsg, doc.Errors[0].Message) - assert.Equal(t, emsg, doc.Errors[1].Message) + assert.Equal(t, emsg, doc.Errors()[0].Message) + assert.Equal(t, emsg, doc.Errors()[1].Message) } @@ -263,14 +283,14 @@ func TestCreateProcs2(t *testing.T) { create type [code].MyType () create procedure [code].FirstProc as table (x int) `) - //repr.Println(doc.Errors) + //repr.Println(doc.Errors()) // Code above was mainly to be able to step through parser in a given way. // First function triggers an error. Then create type is parsed which is // fine sharing a batch with others. 
- require.Equal(t, 1, len(doc.Errors)) + require.Equal(t, 1, len(doc.Errors())) emsg := "a procedure/function must be alone in a batch; use 'go' to split batches" - assert.Equal(t, emsg, doc.Errors[0].Message) + assert.Equal(t, emsg, doc.Errors()[0].Message) } func TestCreateProcsAndCheckForRoutineName(t *testing.T) { @@ -302,12 +322,12 @@ create procedure [code].[transform:safeguarding.Calculation/HEAD](@now datetime2 }, } for _, tc := range testcases { - require.Equal(t, 0, len(tc.doc.Errors)) - assert.Len(t, tc.doc.Creates, 1) - assert.Greater(t, len(tc.doc.Creates[0].Body), tc.expectedIndex) + require.Equal(t, 0, len(tc.doc.Errors())) + assert.Len(t, tc.doc.Creates(), 1) + assert.Greater(t, len(tc.doc.Creates()[0].Body), tc.expectedIndex) assert.Equal(t, fmt.Sprintf(templateRoutineName, tc.expectedProcName), - tc.doc.Creates[0].Body[tc.expectedIndex].RawValue, + tc.doc.Creates()[0].Body[tc.expectedIndex].RawValue, ) } } @@ -322,9 +342,9 @@ end // Code above was mainly to be able to step through parser in a given way. // First function triggers an error. Then create type is parsed which is // fine sharing a batch with others. - require.Equal(t, 2, len(doc.Errors)) - assert.Equal(t, "`go` should be alone on a line without any comments", doc.Errors[0].Message) - assert.Equal(t, "Expected 'declare' or 'create', got: end", doc.Errors[1].Message) + require.Equal(t, 2, len(doc.Errors())) + assert.Equal(t, "`go` should be alone on a line without any comments", doc.Errors()[0].Message) + assert.Equal(t, "Expected 'declare' or 'create', got: end", doc.Errors()[1].Message) } func TestCreateAnnotationHappyDay(t *testing.T) { @@ -342,8 +362,8 @@ create procedure [code].Foo as begin end `) assert.Equal(t, "-- This is part of annotation\n--! key1: a\n--! key2: b\n--! 
key3: [1,2,3]", - doc.Creates[0].DocstringAsString()) - s, err := doc.Creates[0].DocstringYamldoc() + doc.Creates()[0].DocstringAsString()) + s, err := doc.Creates()[0].DocstringYamldoc() assert.NoError(t, err) assert.Equal(t, "key1: a\nkey2: b\nkey3: [1,2,3]", @@ -352,7 +372,7 @@ create procedure [code].Foo as begin end var x struct { Key1 string `yaml:"key1"` } - require.NoError(t, doc.Creates[0].ParseYamlInDocstring(&x)) + require.NoError(t, doc.Creates()[0].ParseYamlInDocstring(&x)) assert.Equal(t, "a", x.Key1) } @@ -367,7 +387,7 @@ create procedure [code].Foo as begin end `) assert.Equal(t, "-- docstring here", - doc.Creates[0].DocstringAsString()) + doc.Creates()[0].DocstringAsString()) } func TestCreateAnnotationErrors(t *testing.T) { @@ -377,7 +397,7 @@ func TestCreateAnnotationErrors(t *testing.T) { -- This comment after yamldoc is illegal; this also prevents multiple embedded YAML documents create procedure [code].Foo as begin end `) - _, err := doc.Creates[0].DocstringYamldoc() + _, err := doc.Creates()[0].DocstringYamldoc() assert.Equal(t, "test.sql:3:1 once embedded yaml document is started (lines prefixed with `--!`), it must continue until create statement", err.Error()) @@ -387,8 +407,251 @@ create procedure [code].Foo as begin end --!key4: 1 create procedure [code].Foo as begin end `) - _, err = doc.Creates[0].DocstringYamldoc() + _, err = doc.Creates()[0].DocstringYamldoc() assert.Equal(t, "test.sql:3:1 YAML document in docstring; missing space after `--!`", err.Error()) } + +func TestParseFilesystems(t *testing.T) { + t.Run("basic parsing of sql files", func(t *testing.T) { + fsys := fstest.MapFS{ + "test1.sql": &fstest.MapFile{ + Data: []byte(` +declare @EnumFoo int = 1; +go +create procedure [code].Proc1 as begin end +`), + }, + "test2.sql": &fstest.MapFile{ + Data: []byte(` +create function [code].Func1() returns int as begin return 1 end +`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + 
assert.Len(t, filenames, 2) + assert.Len(t, doc.Creates(), 2) + assert.Len(t, doc.Declares(), 1) + }) + + t.Run("filters by include tags", func(t *testing.T) { + fsys := fstest.MapFS{ + "included.sql": &fstest.MapFile{ + Data: []byte(`--sqlcode:include-if foo,bar +create procedure [code].Included as begin end +`), + }, + "excluded.sql": &fstest.MapFile{ + Data: []byte(`--sqlcode:include-if baz +create procedure [code].Excluded as begin end +`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, []string{"foo", "bar"}) + require.NoError(t, err) + assert.Len(t, filenames, 1) + assert.Contains(t, filenames[0], "included.sql") + assert.Len(t, doc.Creates(), 1) + assert.Equal(t, "[Included]", doc.Creates()[0].QuotedName.Value) + }) + + t.Run("detects duplicate files with same hash", func(t *testing.T) { + contents := []byte(`create procedure [code].Test as begin end`) + + fs1 := fstest.MapFS{ + "test.sql": &fstest.MapFile{Data: contents}, + } + fs2 := fstest.MapFS{ + "test.sql": &fstest.MapFile{Data: contents}, + } + + _, _, err := ParseFilesystems([]fs.FS{fs1, fs2}, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "exact same contents") + }) + + t.Run("skips non-sqlcode files", func(t *testing.T) { + fsys := fstest.MapFS{ + "regular.sql": &fstest.MapFile{ + Data: []byte(`select * from table1`), + }, + "sqlcode.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Test as begin end`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 1) + assert.Contains(t, filenames[0], "sqlcode.sql") + assert.Len(t, doc.Creates(), 1) + }) + + t.Run("skips hidden directories", func(t *testing.T) { + fsys := fstest.MapFS{ + "visible.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Visible as begin end`), + }, + ".hidden/test.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Hidden as begin end`), + }, + "dir/.git/test.sql": &fstest.MapFile{ + 
Data: []byte(`create procedure [code].Git as begin end`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 1) + assert.Contains(t, filenames[0], "visible.sql") + assert.Len(t, doc.Creates(), 1) + }) + + t.Run("handles dependencies and topological sort", func(t *testing.T) { + fsys := fstest.MapFS{ + "proc1.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Proc1 as begin exec [code].Proc2 end`), + }, + "proc2.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Proc2 as begin select 1 end`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 2) + assert.Len(t, doc.Creates(), 2) + // Proc2 should come before Proc1 due to dependency + assert.Equal(t, "[Proc2]", doc.Creates()[0].QuotedName.Value) + assert.Equal(t, "[Proc1]", doc.Creates()[1].QuotedName.Value) + }) + + t.Run("reports topological sort errors", func(t *testing.T) { + fsys := fstest.MapFS{ + "circular1.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].A as begin exec [code].B end`), + }, + "circular2.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].B as begin exec [code].A end`), + }, + } + + _, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) // filesystem error should be nil + assert.NotEmpty(t, doc.Errors()) // but parsing errors should exist + assert.Contains(t, doc.Errors()[0].Message, "Detected a dependency cycle") + }) + + t.Run("handles multiple filesystems", func(t *testing.T) { + fs1 := fstest.MapFS{ + "test1.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Proc1 as begin end`), + }, + } + fs2 := fstest.MapFS{ + "test2.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Proc2 as begin end`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fs1, fs2}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 2) + assert.Contains(t, 
filenames[0], "fs[0]:") + assert.Contains(t, filenames[1], "fs[1]:") + assert.Len(t, doc.Creates(), 2) + }) + + t.Run("detects sqlcode files by pragma header", func(t *testing.T) { + fsys := fstest.MapFS{ + "test.sql": &fstest.MapFile{ + Data: []byte(`--sqlcode:include-if foo +create procedure NotInCodeSchema.Test as begin end`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, []string{"foo"}) + require.NoError(t, err) + assert.Len(t, filenames, 1) + // Should still parse even though it will have errors (not in [code] schema) + assert.NotEmpty(t, doc.Errors()) + }) + + t.Run("handles pgsql extension", func(t *testing.T) { + fsys := fstest.MapFS{ + "test.pgsql": &fstest.MapFile{ + Data: []byte(` +create procedure [code].test() +language plpgsql +as $$ +begin + perform 1; +end; +$$; +`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 1) + assert.Len(t, doc.Creates(), 1) + assert.Equal(t, &stdlib.Driver{}, doc.Creates()[0].Driver) + }) + + t.Run("empty filesystem returns empty results", func(t *testing.T) { + fsys := fstest.MapFS{} + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Empty(t, filenames) + assert.Empty(t, doc.Creates()) + assert.Empty(t, doc.Declares()) + }) +} + +func TestMatchesIncludeTags(t *testing.T) { + t.Run("empty requirements matches anything", func(t *testing.T) { + assert.True(t, matchesIncludeTags([]string{}, []string{})) + assert.True(t, matchesIncludeTags([]string{}, []string{"foo"})) + }) + + t.Run("all requirements must be met", func(t *testing.T) { + assert.True(t, matchesIncludeTags([]string{"foo", "bar"}, []string{"foo", "bar", "baz"})) + assert.False(t, matchesIncludeTags([]string{"foo", "bar"}, []string{"foo"})) + assert.False(t, matchesIncludeTags([]string{"foo", "bar"}, []string{"bar"})) + }) + + t.Run("exact match", func(t *testing.T) { + assert.True(t, 
matchesIncludeTags([]string{"foo"}, []string{"foo"})) + assert.False(t, matchesIncludeTags([]string{"foo"}, []string{"bar"})) + }) +} + +func TestIsSqlcodeConstVariable(t *testing.T) { + testCases := []struct { + name string + varname string + expected bool + }{ + {"@Enum prefix", "@EnumFoo", true}, + {"@ENUM_ prefix", "@ENUM_FOO", true}, + {"@enum_ prefix", "@enum_foo", true}, + {"@Const prefix", "@ConstFoo", true}, + {"@CONST_ prefix", "@CONST_FOO", true}, + {"@const_ prefix", "@const_foo", true}, + {"regular variable", "@MyVariable", false}, + {"@Global prefix", "@GlobalVar", false}, + {"no @ prefix", "EnumFoo", false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, IsSqlcodeConstVariable(tc.varname)) + }) + } +} diff --git a/sqlparser/pgsql_document.go b/sqlparser/pgsql_document.go new file mode 100644 index 0000000..e97ec32 --- /dev/null +++ b/sqlparser/pgsql_document.go @@ -0,0 +1,552 @@ +package sqlparser + +import ( + "fmt" + + "github.com/jackc/pgx/v5/stdlib" +) + +var PGSQLStatementTokens = []string{"create"} + +type PGSqlDocument struct { + creates []Create + errors []Error + + Pragma +} + +func (d PGSqlDocument) HasErrors() bool { + return len(d.errors) > 0 +} + +func (d *PGSqlDocument) Parse(s *Scanner) error { + err := d.ParsePragmas(s) + if err != nil { + d.errors = append(d.errors, Error{s.Start(), err.Error()}) + } + + return nil +} + +func (d PGSqlDocument) Creates() []Create { + return d.creates +} + +// Not yet implemented +func (d PGSqlDocument) Declares() []Declare { + return nil +} + +func (d PGSqlDocument) Errors() []Error { + return d.errors +} + +func (d PGSqlDocument) Empty() bool { + return len(d.creates) == 0 +} + +func (d PGSqlDocument) Sort() { + +} + +func (d PGSqlDocument) Include(other Document) { + +} + +func (d PGSqlDocument) WithoutPos() Document { + return &PGSqlDocument{} +} + +// No GO batch separator: +// +// PostgreSQL uses semicolons (;) to separate statements, 
not GO. +// Multiple CREATE statements can exist in the same file. +// +// No top-level DECLARE: +// +// In PostgreSQL, DECLARE is only used inside function/procedure bodies within BEGIN...END blocks, not as top-level batch statements. +// +// Multiple CREATEs per batch: +// +// Unlike T-SQL which requires procedures/functions to be alone in a batch, PostgreSQL allows multiple CREATE statements separated by semicolons. +// +// Semicolon handling: +// +// The semicolon is a statement terminator, not a batch separator, so parsing continues after encountering one. +// +// Dollar quoting: +// +// PostgreSQL uses $$ or $tag$ for quoting function bodies instead of BEGIN...END (this would be handled in parseCreate). +// +// CREATE OR REPLACE: +// +// PostgreSQL commonly uses CREATE OR REPLACE which would need special handling in parseCreate. +// +// Schema qualification: +// +// PostgreSQL uses schema.object notation rather than [schema].[object]. +func (doc *PGSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { + batch := &Batch{ + TokenHandlers: map[string]func(*Scanner, *Batch) bool{ + "create": func(s *Scanner, n *Batch) bool { + // Parse CREATE FUNCTION, CREATE PROCEDURE, CREATE TYPE, etc. + c := doc.parseCreate(s, n.CreateStatements) + c.Driver = &stdlib.Driver{} + + // Prepend any leading comments/whitespace + c.Body = append(n.Nodes, c.Body...) + c.Docstring = n.DocString + doc.creates = append(doc.creates, c) + + return false + }, + }, + } + + hasMore = batch.Parse(s) + if batch.HasErrors() { + doc.errors = append(doc.errors, batch.Errors...) 
+ } + + return hasMore + + // var nodes []Unparsed + // var docstring []PosString + // newLineEncounteredInDocstring := false + + // for { + // tt := s.TokenType() + // switch tt { + // case EOFToken: + // return false + // case WhitespaceToken, MultilineCommentToken: + // nodes = append(nodes, CreateUnparsed(s)) + // // do not reset docstring for a single trailing newline + // t := s.Token() + // if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { + // newLineEncounteredInDocstring = true + // } else { + // docstring = nil + // } + // s.NextToken() + // case SinglelineCommentToken: + // // Build up a list of single line comments for the "docstring"; + // // it is reset whenever we encounter something else + // docstring = append(docstring, PosString{s.Start(), s.Token()}) + // nodes = append(nodes, CreateUnparsed(s)) + // newLineEncounteredInDocstring = false + // s.NextToken() + // case ReservedWordToken: + // switch s.ReservedWord() { + // case "declare": + // // PostgreSQL doesn't have top-level DECLARE batches like T-SQL + // // DECLARE is only used inside function/procedure bodies + // if isFirst { + // doc.addError(s, "PostgreSQL 'declare' is used inside function bodies, not as top-level batch statements") + // } + // nodes = append(nodes, CreateUnparsed(s)) + // s.NextToken() + // docstring = nil + // case "create": + // // Parse CREATE FUNCTION, CREATE PROCEDURE, CREATE TYPE, etc. + // createStart := len(doc.creates) + // c := doc.parseCreate(s, createStart) + // c.Driver = &stdlib.Driver{} + + // // Prepend any leading comments/whitespace + // c.Body = append(nodes, c.Body...) 
+ // c.Docstring = docstring + // doc.creates = append(doc.creates, c) + + // // Reset for next statement + // nodes = nil + // docstring = nil + // newLineEncounteredInDocstring = false + // default: + // doc.addError(s, "Expected 'create', got: "+s.ReservedWord()) + // s.NextToken() + // docstring = nil + // } + // case SemicolonToken: + // // PostgreSQL uses semicolons as statement terminators + // // Multiple CREATE statements can exist in same file + // nodes = append(nodes, CreateUnparsed(s)) + // s.NextToken() + // // Continue parsing - don't return like T-SQL does with GO + // case BatchSeparatorToken: + // // PostgreSQL doesn't use GO batch separators + // // Q: Do we want to use GO batch separators as a feature of sqlcode? + // doc.addError(s, "PostgreSQL does not use 'GO' batch separators; use semicolons instead") + // s.NextToken() + // docstring = nil + // default: + // doc.addError(s, fmt.Sprintf("Unexpected token in PostgreSQL document: %s", s.Token())) + // s.NextToken() + // docstring = nil + // } + // } +} + +// parseCreate parses PostgreSQL CREATE statements (FUNCTION, PROCEDURE, TYPE, etc.) +// Position is *on* the CREATE token. +// +// PostgreSQL CREATE syntax differences from T-SQL: +// - Supports CREATE OR REPLACE for functions/procedures +// - Uses dollar quoting ($$...$$) or $tag$...$tag$ for function bodies +// - Schema qualification uses dot notation: schema.function_name +// - Double-quoted identifiers preserve case: "MyFunction" +// - Function parameters use different syntax: func(param1 type1, param2 type2) +// - RETURNS clause specifies return type +// - LANGUAGE clause (plpgsql, sql, etc.) is required +// - Function characteristics: IMMUTABLE, STABLE, VOLATILE, PARALLEL SAFE, etc. +// +// We parse until we hit a semicolon or EOF, tracking dependencies on other objects. 
+func (doc *PGSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (result Create) {
+	var body []Unparsed
+	// Tokens consumed so far are always handed back via result.Body,
+	// including on the early error returns below (which previously
+	// dropped everything accumulated in `body`).
+	defer func() { result.Body = body }()
+
+	// Copy the CREATE token
+	CopyToken(s, &body)
+	s.NextNonWhitespaceCommentToken()
+
+	// Check for OR REPLACE
+	// NOTE: "or replace" doesn't make sense within sqlcode as this will be created within a new
+	// schema.
+	if s.TokenType() == ReservedWordToken && s.ReservedWord() == "or" {
+		CopyToken(s, &body)
+		s.NextNonWhitespaceCommentToken()
+
+		if s.TokenType() == ReservedWordToken && s.ReservedWord() == "replace" {
+			CopyToken(s, &body)
+			s.NextNonWhitespaceCommentToken()
+		} else {
+			doc.addError(s, "Expected 'REPLACE' after 'OR'")
+			RecoverToNextStatementCopying(s, &body, PGSQLStatementTokens)
+			return
+		}
+	}
+
+	// Parse the object type (FUNCTION, PROCEDURE, TYPE, etc.)
+	if s.TokenType() != ReservedWordToken {
+		doc.addError(s, "Expected object type after CREATE (e.g., FUNCTION, PROCEDURE, TYPE)")
+		RecoverToNextStatementCopying(s, &body, PGSQLStatementTokens)
+		return
+	}
+
+	createType := s.ReservedWord()
+	result.CreateType = createType
+	CopyToken(s, &body)
+	s.NextNonWhitespaceCommentToken()
+
+	// Validate supported CREATE types
+	switch createType {
+	case "function", "procedure", "type":
+		// Supported types
+	default:
+		doc.addError(s, fmt.Sprintf("Unsupported CREATE type for PostgreSQL: %s", createType))
+		RecoverToNextStatementCopying(s, &body, PGSQLStatementTokens)
+		return
+	}
+
+	// Insist on [code] to provide the ability for sqlcode to patch function bodies
+	// with references to other sqlcode objects.
+	if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" {
+		doc.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType))
+		RecoverToNextStatementCopying(s, &body, PGSQLStatementTokens)
+		return
+	}
+	var err error
+	result.QuotedName, err = ParseCodeschemaName(s, &body, PGSQLStatementTokens)
+	if err != nil {
+		doc.addError(s, err.Error())
+	}
+	if result.QuotedName.String() == "" {
+		return
+	}
+
+	// Parse function/procedure signature or type definition
+	switch createType {
+	case "function", "procedure":
+		doc.parseFunctionSignature(s, &body, &result)
+	case "type":
+		doc.parseTypeDefinition(s, &body, &result)
+	}
+
+	// Parse the rest of the CREATE statement body until semicolon or EOF
+	doc.parseCreateBody(s, &body, &result)
+
+	return
+}
+
+// parseQualifiedName parses schema-qualified or simple object names
+// Supports: simple_name, schema.name, "Quoted Name", schema."Quoted Name"
+func (doc *PGSqlDocument) parseQualifiedName(s *Scanner, body *[]Unparsed) string {
+	var nameParts []string
+
+	for {
+		switch s.TokenType() {
+		case UnquotedIdentifierToken:
+			nameParts = append(nameParts, s.Token())
+			CopyToken(s, body)
+			s.NextNonWhitespaceCommentToken()
+		case QuotedIdentifierToken:
+			// PostgreSQL uses double quotes for case-sensitive identifiers
+			nameParts = append(nameParts, s.Token())
+			CopyToken(s, body)
+			s.NextNonWhitespaceCommentToken()
+		default:
+			if len(nameParts) == 0 {
+				return ""
+			}
+			// Return the last part as the object name (without schema)
+			return nameParts[len(nameParts)-1]
+		}
+
+		// Check for dot separator (schema.object)
+		if s.TokenType() == DotToken {
+			CopyToken(s, body)
+			s.NextNonWhitespaceCommentToken()
+			continue
+		}
+
+		break
+	}
+
+	if len(nameParts) == 0 {
+		return ""
+	}
+	return nameParts[len(nameParts)-1]
+}
+
+// parseFunctionSignature parses function/procedure parameters and RETURNS clause
+func (doc *PGSqlDocument) parseFunctionSignature(s 
*Scanner, body *[]Unparsed, result *Create) { + // Expect opening parenthesis for parameters + if s.TokenType() != LeftParenToken { + doc.addError(s, "Expected '(' for function parameters") + return + } + + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + + // Parse parameters until closing parenthesis + parenDepth := 1 + for parenDepth > 0 { + switch s.TokenType() { + case EOFToken: + doc.addError(s, "Unexpected EOF in function parameters") + return + case LeftParenToken: + parenDepth++ + CopyToken(s, body) + s.NextToken() + case RightParenToken: + parenDepth-- + CopyToken(s, body) + s.NextToken() + case SemicolonToken: + doc.addError(s, "Unexpected semicolon in function parameters") + return + default: + CopyToken(s, body) + s.NextToken() + } + } + + s.SkipWhitespaceComments() + + // Parse RETURNS clause (for functions, not procedures) + if result.CreateType == "function" { + if s.TokenType() == ReservedWordToken && s.ReservedWord() == "returns" { + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + + // Handle RETURNS TABLE(...) + if s.TokenType() == ReservedWordToken && s.ReservedWord() == "table" { + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + + if s.TokenType() == LeftParenToken { + doc.parseReturnTable(s, body) + } + } else { + // Parse simple return type + doc.parseTypeExpression(s, body) + } + } + } +} + +// parseReturnTable parses RETURNS TABLE(...) syntax +func (doc *PGSqlDocument) parseReturnTable(s *Scanner, body *[]Unparsed) { + parenDepth := 0 + for { + switch s.TokenType() { + case EOFToken, SemicolonToken: + return + case LeftParenToken: + parenDepth++ + case RightParenToken: + parenDepth-- + CopyToken(s, body) + s.NextToken() + if parenDepth == 0 { + return + } + continue + } + CopyToken(s, body) + s.NextToken() + } +} + +// parseTypeExpression parses PostgreSQL type expressions +// Supports: int, integer, text, varchar(n), numeric(p,s), arrays (int[]), etc. 
+func (doc *PGSqlDocument) parseTypeExpression(s *Scanner, body *[]Unparsed) { + // Parse base type + if s.TokenType() != UnquotedIdentifierToken && s.TokenType() != ReservedWordToken { + return + } + + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + + // Handle array notation: type[] + // if s.TokenType() == LeftBracketToken { + // CopyToken(s, body) + // s.NextNonWhitespaceCommentToken() + + // if s.TokenType() == RightBracketToken { + // CopyToken(s, body) + // s.NextNonWhitespaceCommentToken() + // } + // } + + // Handle type parameters: varchar(100), numeric(10,2) + if s.TokenType() == LeftParenToken { + parenDepth := 1 + CopyToken(s, body) + s.NextToken() + + for parenDepth > 0 { + switch s.TokenType() { + case EOFToken, SemicolonToken: + return + case LeftParenToken: + parenDepth++ + case RightParenToken: + parenDepth-- + } + CopyToken(s, body) + s.NextToken() + } + } +} + +// parseTypeDefinition parses CREATE TYPE syntax +// Supports: ENUM, composite types, range types +func (doc *PGSqlDocument) parseTypeDefinition(s *Scanner, body *[]Unparsed, result *Create) { + // TYPE definitions use AS keyword + if s.TokenType() == ReservedWordToken && s.ReservedWord() == "as" { + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + + // Check for ENUM, RANGE, or composite type + if s.TokenType() == ReservedWordToken { + typeKind := s.ReservedWord() + switch typeKind { + case "enum", "range": + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + } + } + } +} + +// parseCreateBody parses the body of a CREATE statement +// Handles dollar-quoted strings, tracks dependencies, continues until semicolon/EOF +func (doc *PGSqlDocument) parseCreateBody(s *Scanner, body *[]Unparsed, result *Create) { + dollarQuoteDepth := 0 + var currentDollarTag string + + for { + switch s.TokenType() { + case EOFToken: + return + case SemicolonToken: + // Statement terminator - we're done + CopyToken(s, body) + s.NextToken() + return + case DollarQuotedStringStartToken: + // 
PostgreSQL dollar quoting: $$...$$ or $tag$...$tag$ + currentDollarTag = s.Token() + dollarQuoteDepth++ + CopyToken(s, body) + s.NextToken() + case DollarQuotedStringEndToken: + if s.Token() == currentDollarTag { + dollarQuoteDepth-- + } + CopyToken(s, body) + s.NextToken() + if dollarQuoteDepth == 0 { + currentDollarTag = "" + } + case UnquotedIdentifierToken, QuotedIdentifierToken: + // Track dependencies on tables/views/functions + // In PostgreSQL, identifiers can be qualified: schema.object + identifier := s.Token() + + // Check if this might be a dependency (after FROM, JOIN, etc.) + if doc.mightBeDependency(s) { + // Extract just the object name (without schema prefix) + objectName := doc.extractObjectName(identifier) + result.DependsOn = append(result.DependsOn, PosString{s.Start(), objectName}) + } + + CopyToken(s, body) + s.NextToken() + default: + CopyToken(s, body) + s.NextToken() + } + } +} + +// mightBeDependency checks if current context suggests a table/view/function reference +func (doc *PGSqlDocument) mightBeDependency(s *Scanner) bool { + // Simple heuristic: look back for FROM, JOIN, INTO, etc. 
+ // This would need to track parse context for accurate dependency detection + return false // Placeholder - implement context-aware dependency tracking +} + +// extractObjectName extracts object name from schema-qualified identifier +func (doc *PGSqlDocument) extractObjectName(identifier string) string { + // Handle schema.object notation + // For now, return as-is; proper implementation would split on dot + return identifier +} + +func (doc *PGSqlDocument) addError(s *Scanner, err string) { + doc.errors = append(doc.errors, Error{ + s.Start(), err, + }) +} + +func (doc *PGSqlDocument) parseDeclareBatch(s *Scanner) (hasMore bool) { + // PostgreSQL doesn't have top-level DECLARE batches like T-SQL + // DECLARE is only used inside function/procedure bodies (in BEGIN...END blocks) + doc.addError(s, "PostgreSQL does not support top-level DECLARE statements outside of function bodies") + RecoverToNextStatement(s, PGSQLStatementTokens) + return false +} + +func (doc *PGSqlDocument) parseBatchSeparator(s *Scanner) { + // PostgreSQL doesn't use GO batch separators + doc.addError(s, "PostgreSQL does not use 'GO' batch separators; use semicolons") + s.NextToken() +} diff --git a/sqlparser/pgsql_document_test.go b/sqlparser/pgsql_document_test.go new file mode 100644 index 0000000..2d3c8af --- /dev/null +++ b/sqlparser/pgsql_document_test.go @@ -0,0 +1,254 @@ +package sqlparser + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDocument_PostgreSQL17_parseCreate(t *testing.T) { + t.Run("parses PostgreSQL function with dollar quoting", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test_func() returns int as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + assert.Equal(t, "test_func", create.QuotedName.Value) + }) + + 
t.Run("parses PostgreSQL procedure", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create procedure insert_data(a integer, b integer) language sql as $$ insert into tbl values (a, b); $$") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "procedure", create.CreateType) + assert.Equal(t, "insert_data", create.QuotedName.Value) + }) + + t.Run("parses CREATE OR REPLACE", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create or replace function test_func() returns int as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) + + t.Run("parses schema-qualified name", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function public.test_func() returns int as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Contains(t, create.QuotedName.Value, "test_func") + }) + + t.Run("parses RETURNS TABLE", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function get_users() returns table(id int, name text) as $$ select id, name from users; $$ language sql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) + + t.Run("tracks dependencies with schema prefix", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test() returns int as $$ select * from public.table1 join public.table2 on table1.id = table2.id; $$ language sql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + require.Len(t, create.DependsOn, 2) + }) + + t.Run("parses volatility categories", func(t *testing.T) { + doc := 
&TSqlDocument{} + s := NewScanner("test.pgsql", "create function test_func() returns int immutable as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) + + t.Run("parses PARALLEL SAFE", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test_func() returns int parallel safe as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) +} + +func TestDocument_PostgreSQL17_Types(t *testing.T) { + t.Run("parses composite type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create type address_type as (street text, city text, zip varchar(10))") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "type", create.CreateType) + }) + + t.Run("parses enum type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create type mood as enum ('sad', 'ok', 'happy')") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "type", create.CreateType) + }) + + t.Run("parses range type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create type float_range as range (subtype = float8, subtype_diff = float8mi)") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "type", create.CreateType) + }) +} + +func TestDocument_PostgreSQL17_Extensions(t *testing.T) { + t.Run("parses JSON functions PostgreSQL 17", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test() returns jsonb as $$ select json_serialize(data) from table1; $$ language sql") + s.NextToken() + 
s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) + + t.Run("parses MERGE statement (PostgreSQL 15+)", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function do_merge() returns void as $$ merge into target using source on target.id = source.id when matched then update set value = source.value; $$ language sql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) +} + +func TestDocument_PostgreSQL17_Identifiers(t *testing.T) { + t.Run("parses double-quoted identifiers", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", `create function "Test Func"() returns int as $$ begin return 1; end; $$ language plpgsql`) + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Contains(t, create.QuotedName.Value, "Test Func") + }) + + t.Run("parses case-sensitive identifiers", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", `create function "TestFunc"() returns int as $$ begin return 1; end; $$ language plpgsql`) + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Contains(t, create.QuotedName.Value, "TestFunc") + }) +} + +func TestDocument_PostgreSQL17_Datatypes(t *testing.T) { + t.Run("parses array types", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "integer[]") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "integer[]", typ.BaseType) + }) + + t.Run("parses serial types", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "serial") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "serial", typ.BaseType) + }) + + t.Run("parses text type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", 
"text") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "text", typ.BaseType) + }) + + t.Run("parses jsonb type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "jsonb") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "jsonb", typ.BaseType) + }) + + t.Run("parses uuid type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "uuid") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "uuid", typ.BaseType) + }) +} + +func TestDocument_PostgreSQL17_BatchSeparator(t *testing.T) { + t.Run("PostgreSQL uses semicolon not GO", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test1() returns int as $$ begin return 1; end; $$ language plpgsql; create function test2() returns int as $$ begin return 2; end; $$ language plpgsql;") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create1 := doc.parseCreate(s, 0) + assert.Equal(t, "test1", create1.QuotedName.Value) + + // Move to next statement + s.NextNonWhitespaceCommentToken() + s.NextNonWhitespaceCommentToken() + + create2 := doc.parseCreate(s, 1) + assert.Equal(t, "test2", create2.QuotedName.Value) + }) +} diff --git a/sqlparser/pragma.go b/sqlparser/pragma.go new file mode 100644 index 0000000..f0ee990 --- /dev/null +++ b/sqlparser/pragma.go @@ -0,0 +1,41 @@ +package sqlparser + +import ( + "fmt" + "strings" +) + +type Pragma struct { + pragmas []string +} + +func (d Pragma) PragmaIncludeIf() []string { + return d.pragmas +} + +func (d *Pragma) parseSinglePragma(s *Scanner) error { + pragma := strings.TrimSpace(strings.TrimPrefix(s.Token(), "--sqlcode:")) + if pragma == "" { + return nil + } + parts := strings.Split(pragma, " ") + + if len(parts) != 2 || parts[0] != "include-if" { + return fmt.Errorf("Illegal pragma: %s", s.Token()) + } + + d.pragmas = append(d.pragmas, strings.Split(parts[1], ",")...) 
+	return nil
+}
+
+func (d *Pragma) ParsePragmas(s *Scanner) error {
+	for s.TokenType() == PragmaToken {
+		err := d.parseSinglePragma(s)
+		if err != nil {
+			return err
+		}
+		s.NextNonWhitespaceToken()
+	}
+
+	return nil
+}
diff --git a/sqlparser/scanner.go b/sqlparser/scanner.go
index 3103894..18b2b77 100644
--- a/sqlparser/scanner.go
+++ b/sqlparser/scanner.go
@@ -1,11 +1,12 @@
 package sqlparser
 
 import (
-	"github.com/smasher164/xid"
 	"regexp"
 	"strings"
 	"unicode"
 	"unicode/utf8"
+
+	"github.com/smasher164/xid"
 )
 
 // dedicated type for reference to file, in case we need to refactor this later..
@@ -40,6 +41,10 @@ type Scanner struct {
 	reservedWord string // in the event that the token is a ReservedWordToken, this contains the lower-case version
 }
 
+func NewScanner(path FileRef, input string) *Scanner {
+	return &Scanner{input: input, file: path}
+}
+
 type TokenType int
 
 func (s *Scanner) TokenType() TokenType {
@@ -375,6 +380,7 @@ func (s *Scanner) scanWhitespace() TokenType {
 	return WhitespaceToken
 }
 
+// tsql (mssql) reserved words
 var reservedWords = map[string]struct{}{
 	"add":      struct{}{},
 	"external": struct{}{},
diff --git a/sqlparser/tokentype.go b/sqlparser/tokentype.go
index 835e605..644d8da 100644
--- a/sqlparser/tokentype.go
+++ b/sqlparser/tokentype.go
@@ -38,6 +38,10 @@ const (
 	UnexpectedCharacterToken
 	NonUTF8ErrorToken
 
+	// PGSQL specific
+	DollarQuotedStringStartToken
+	DollarQuotedStringEndToken
+
 	BatchSeparatorToken
 	MalformedBatchSeparatorToken
 	EOFToken
@@ -90,6 +94,9 @@ var tokenToDescription = map[TokenType]string{
 	UnexpectedCharacterToken: "UnexpectedCharacterToken",
 	NonUTF8ErrorToken:        "NonUTF8ErrorToken",
 
+	DollarQuotedStringStartToken: "DollarQuotedStringStartToken",
+	DollarQuotedStringEndToken:   "DollarQuotedStringEndToken",
+
 	// After a lot of back and forth we added the batch separater to the scanner.
 	// We implement sqlcmd's use of the go
 	// do separate batches. 
sqlcmd will only support GO at the start of
// --- new file: sqlparser/tsql_document.go ---
package sqlparser

import (
	"fmt"
	"sort"
	"strings"

	mssql "github.com/microsoft/go-mssqldb"
)

// TSQLStatementTokens lists the tokens that begin a new top-level statement;
// used as resynchronization points when recovering from a parse error.
var TSQLStatementTokens = []string{"create", "declare", "go"}

// TSqlDocument is the parsed representation of one or more T-SQL (MS SQL
// Server) sqlcode source files: constant declarations, create statements,
// and any parse errors encountered along the way.
type TSqlDocument struct {
	// NOTE(review): pragmaIncludeIf appears unused in this file; pragma tags
	// live in the embedded Pragma — confirm before removing.
	pragmaIncludeIf []string
	creates         []Create
	declares        []Declare
	errors          []Error

	Pragma
}

// HasErrors reports whether any parse errors have been recorded.
func (d TSqlDocument) HasErrors() bool {
	return len(d.errors) > 0
}

// Creates returns the parsed create statements.
func (d TSqlDocument) Creates() []Create {
	return d.creates
}

// Declares returns the parsed constant declarations.
func (d TSqlDocument) Declares() []Declare {
	return d.declares
}

// Errors returns the parse errors recorded so far.
func (d TSqlDocument) Errors() []Error {
	return d.errors
}

// Parse consumes the entire input: leading pragmas first, then batch after
// batch until parseBatch reports there is no more input. Errors are
// accumulated on the document (see Errors); the returned error is
// currently always nil.
func (d *TSqlDocument) Parse(s *Scanner) error {
	err := d.ParsePragmas(s)
	if err != nil {
		d.addError(s, err.Error())
	}

	hasMore := d.parseBatch(s, true)
	for hasMore {
		hasMore = d.parseBatch(s, false)
	}

	return nil
}

// Sort topologically sorts the create statements by their declared
// dependencies. A sort failure (e.g. a dependency cycle) is recorded as a
// document error and leaves the original order untouched.
func (d *TSqlDocument) Sort() {
	// Do the topological sort; and include any error with it as part
	// of `result`, *not* return it as err
	sortedCreates, errpos, sortErr := TopologicalSort(d.creates)

	if sortErr != nil {
		d.errors = append(d.errors, Error{
			Pos:     errpos,
			Message: sortErr.Error(),
		})
	} else {
		d.creates = sortedCreates
	}
}

// Transform a TSqlDocument to remove all Position information; this is used
// to more easily write assertions on a parsed document in tests.
+func (d TSqlDocument) WithoutPos() Document { + var cs []Create + for _, x := range d.creates { + cs = append(cs, x.WithoutPos()) + } + var ds []Declare + for _, x := range d.declares { + ds = append(ds, x.WithoutPos()) + } + var es []Error + for _, x := range d.errors { + es = append(es, x.WithoutPos()) + } + return &TSqlDocument{ + creates: cs, + declares: ds, + errors: es, + } +} + +func (d *TSqlDocument) Include(other Document) { + // Do not copy pragmaIncludeIf, since that is local to a single file. + // Its contents is also present in each Create. + d.declares = append(d.declares, other.Declares()...) + d.creates = append(d.creates, other.Creates()...) + d.errors = append(d.errors, other.Errors()...) +} + +func (d TSqlDocument) Empty() bool { + return len(d.creates) == 0 || len(d.declares) == 0 +} + +func (d *TSqlDocument) addError(s *Scanner, msg string) { + d.errors = append(d.errors, Error{ + Pos: s.Start(), + Message: msg, + }) +} + +func (d *TSqlDocument) unexpectedTokenError(s *Scanner) { + d.addError(s, "Unexpected: "+s.Token()) +} + +func (doc *TSqlDocument) parseTypeExpression(s *Scanner) (t Type) { + parseArgs := func() { + // parses *after* the initial (; consumes trailing ) + for { + switch { + case s.TokenType() == NumberToken: + t.Args = append(t.Args, s.Token()) + case s.TokenType() == UnquotedIdentifierToken && s.TokenLower() == "max": + t.Args = append(t.Args, "max") + default: + doc.unexpectedTokenError(s) + RecoverToNextStatement(s, TSQLStatementTokens) + return + } + s.NextNonWhitespaceCommentToken() + switch { + case s.TokenType() == CommaToken: + s.NextNonWhitespaceCommentToken() + continue + case s.TokenType() == RightParenToken: + s.NextNonWhitespaceCommentToken() + return + default: + doc.unexpectedTokenError(s) + RecoverToNextStatement(s, TSQLStatementTokens) + return + } + } + } + + if s.TokenType() != UnquotedIdentifierToken { + panic("assertion failed, bug in caller") + } + t.BaseType = s.Token() + 
s.NextNonWhitespaceCommentToken() + if s.TokenType() == LeftParenToken { + s.NextNonWhitespaceCommentToken() + parseArgs() + } + return +} + +func (doc *TSqlDocument) parseDeclare(s *Scanner) (result []Declare) { + declareStart := s.Start() + // parse what is *after* the `declare` reserved keyword +loop: + for { + if s.TokenType() != VariableIdentifierToken { + doc.unexpectedTokenError(s) + RecoverToNextStatement(s, TSQLStatementTokens) + return + } + + variableName := s.Token() + if !strings.HasPrefix(strings.ToLower(variableName), "@enum") && + !strings.HasPrefix(strings.ToLower(variableName), "@global") && + !strings.HasPrefix(strings.ToLower(variableName), "@const") { + doc.addError(s, "sqlcode constants needs to have names starting with @Enum, @Global or @Const: "+variableName) + } + + s.NextNonWhitespaceCommentToken() + var variableType Type + switch s.TokenType() { + case EqualToken: + doc.addError(s, "sqlcode constants needs a type declared explicitly") + s.NextNonWhitespaceCommentToken() + case UnquotedIdentifierToken: + variableType = doc.parseTypeExpression(s) + } + + if s.TokenType() != EqualToken { + doc.addError(s, "sqlcode constants needs to be assigned at once using =") + RecoverToNextStatement(s, TSQLStatementTokens) + } + + switch s.NextNonWhitespaceCommentToken() { + case NumberToken, NVarcharLiteralToken, VarcharLiteralToken: + declare := Declare{ + Start: declareStart, + Stop: s.Stop(), + VariableName: variableName, + Datatype: variableType, + Literal: CreateUnparsed(s), + } + result = append(result, declare) + default: + doc.unexpectedTokenError(s) + RecoverToNextStatement(s, TSQLStatementTokens) + return + } + + switch s.NextNonWhitespaceCommentToken() { + case CommaToken: + s.NextNonWhitespaceCommentToken() + continue + case SemicolonToken: + s.NextNonWhitespaceCommentToken() + break loop + default: + break loop + } + } + if len(result) == 0 { + doc.addError(s, "incorrect syntax; no variables successfully declared") + } + return +} + +func 
(doc *TSqlDocument) parseBatchSeparator(s *Scanner) {
	// just saw a 'go'; just make sure there's nothing bad trailing it
	// (if there is, convert to errors and move on until the line is consumed
	errorEmitted := false
	// continuously process tokens until a non-whitespace, non-malformed token is encountered.
	for {
		switch s.NextToken() {
		case WhitespaceToken:
			continue
		case MalformedBatchSeparatorToken:
			// Only report trailing junk after `go` once per separator.
			if !errorEmitted {
				doc.addError(s, "`go` should be alone on a line without any comments")
				errorEmitted = true
			}
			continue
		default:
			return
		}
	}
}

// parseDeclareBatch parses a batch that consists exclusively of `declare`
// statements. The scanner must be positioned on a `declare` reserved word
// (enforced by panic). Returns true if a batch separator was consumed and
// more input may follow, false at EOF.
func (doc *TSqlDocument) parseDeclareBatch(s *Scanner) (hasMore bool) {
	if s.ReservedWord() != "declare" {
		panic("assertion failed, incorrect use in caller")
	}
	for {
		tt := s.TokenType()
		switch {
		case tt == EOFToken:
			return false
		case tt == ReservedWordToken && s.ReservedWord() == "declare":
			s.NextNonWhitespaceCommentToken()
			d := doc.parseDeclare(s)
			doc.declares = append(doc.declares, d...)
		case tt == ReservedWordToken && s.ReservedWord() != "declare":
			doc.addError(s, "Only 'declare' allowed in this batch")
			RecoverToNextStatement(s, TSQLStatementTokens)
		case tt == BatchSeparatorToken:
			doc.parseBatchSeparator(s)
			return true
		default:
			doc.unexpectedTokenError(s)
			RecoverToNextStatement(s, TSQLStatementTokens)
		}
	}
}

// parseBatch parses one `go`-delimited batch by delegating to a Batch with
// handlers for the two statement kinds sqlcode cares about: `declare`
// (only legal in the first batch) and `create`. Batch-level errors are
// merged into the document. Returns whether more batches follow.
func (doc *TSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) {
	batch := &Batch{
		TokenHandlers: map[string]func(*Scanner, *Batch) bool{
			"declare": func(s *Scanner, n *Batch) bool {
				// First declare-statement; enter a mode where we assume all contents
				// of batch are declare statements
				if !isFirst {
					doc.addError(s, "'declare' statement only allowed in first batch")
				}

				// regardless of errors, go on and parse as far as we get...
				return doc.parseDeclareBatch(s)
			},
			"create": func(s *Scanner, n *Batch) bool {
				// should be start of create procedure or create function...
				c := doc.parseCreate(s, n.CreateStatements)
				// This document type always targets MS SQL Server.
				c.Driver = &mssql.Driver{}

				// *prepend* what we saw before getting to the 'create'
				n.CreateStatements++
				c.Body = append(n.Nodes, c.Body...)
				c.Docstring = n.DocString
				doc.creates = append(doc.creates, c)
				return false
			},
		},
	}
	hasMore = batch.Parse(s)
	if batch.HasErrors() {
		doc.errors = append(doc.errors, batch.Errors...)
	}

	return hasMore
}

// parseCreate parses anything that starts with "create". Position is
// *on* the create token.
// At this stage in sqlcode parser development we're only interested
// in procedures/functions/types as opaque blocks of SQL code where
// we only track dependencies between them and their declared name;
// so we treat them with the same code. We consume until the end of
// the batch; only one declaration allowed per batch. Everything
// parsed here will also be added to `batch`. On any error, copying
// to batch stops / becomes erratic..
func (d *TSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (result Create) {
	if s.ReservedWord() != "create" {
		panic("illegal use by caller")
	}
	CopyToken(s, &result.Body)

	NextTokenCopyingWhitespace(s, &result.Body)

	createType := strings.ToLower(s.Token())
	if !(createType == "procedure" || createType == "function" || createType == "type") {
		d.addError(s, fmt.Sprintf("sqlcode only supports creating procedures, functions or types; not `%s`", createType))
		RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens)
		return
	}
	if (createType == "procedure" || createType == "function") && createCountInBatch > 0 {
		d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches")
		RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens)
		return
	}

	result.CreateType = createType
	CopyToken(s, &result.Body)

	NextTokenCopyingWhitespace(s, &result.Body)

	// Insist on [code].
	if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" {
		d.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType))
		RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens)
		return
	}
	var err error
	result.QuotedName, err = ParseCodeschemaName(s, &result.Body, TSQLStatementTokens)
	if err != nil {
		d.addError(s, err.Error())
	}
	if result.QuotedName.String() == "" {
		return
	}

	// We have matched "create [code]."; at this
	// point we copy the rest until the batch ends; *but* track dependencies
	// + some other details mentioned below

	//firstAs := true // See comment below on rowcount

tailloop:
	for {
		tt := s.TokenType()
		switch {
		case tt == ReservedWordToken && s.ReservedWord() == "create":
			// So, we're currently parsing 'create ...' and we see another 'create'.
			// We split in two cases depending on the context we are currently in
			// (createType is referring to how we entered this function, *NOT* the
			// `create` statement we are looking at now
			switch createType { // note: this is the *outer* create type, not the one of current scanner position
			case "function", "procedure":
				// Within a function/procedure we can allow 'create index', 'create table' and nothing
				// else. (Well, only procedures can have them, but we'll leave it to T-SQL to complain
				// about that aspect, not relevant for batch / dependency parsing)
				//
				// What is important is a function/procedure/type isn't started on without a 'go'
				// in between; so we block those 3 from appearing in the same batch
				CopyToken(s, &result.Body)
				NextTokenCopyingWhitespace(s, &result.Body)
				tt2 := s.TokenType()

				if (tt2 == ReservedWordToken && (s.ReservedWord() == "function" || s.ReservedWord() == "procedure")) ||
					(tt2 == UnquotedIdentifierToken && s.TokenLower() == "type") {
					RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens)
					d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches")
					return
				}
			case "type":
				// We allow more than one type creation in a batch; and 'create' can never appear
				// scoped within 'create type'. So at a new create we are done with the previous
				// one, and return it -- the caller can then re-enter this function from the top
				break tailloop
			default:
				panic("assertion failed")
			}

		case tt == EOFToken || tt == BatchSeparatorToken:
			break tailloop
		case tt == QuotedIdentifierToken && s.Token() == "[code]":
			// Parse a dependency
			dep, err := ParseCodeschemaName(s, &result.Body, TSQLStatementTokens)
			if err != nil {
				d.addError(s, err.Error())
			}
			// Linear-scan dedup; dependency lists are expected to be short.
			found := false
			for _, existing := range result.DependsOn {
				if existing.Value == dep.Value {
					found = true
					break
				}
			}
			if !found {
				result.DependsOn = append(result.DependsOn, dep)
			}
		case tt == ReservedWordToken && s.Token() == "as":
			CopyToken(s, &result.Body)
			NextTokenCopyingWhitespace(s, &result.Body)
			/*
				TODO: Fix and re-enable
				This code add RoutineName for convenience. So:

				create procedure [code@5420c0269aaf].Test as
				begin
				select 1
				end
				go

				becomes:

				create procedure [code@5420c0269aaf].Test as
				declare @RoutineName nvarchar(128)
				set @RoutineName = 'Test'
				begin
				select 1
				end
				go

				However, for some very strange reason, @@rowcount is 1 with the first version,
				and it is 2 with the second version.
				if firstAs {
					// Add the `RoutineName` token as a convenience, so that we can refer to the procedure/function name
					// from inside the procedure (for example, when logging)
					if result.CreateType == "procedure" {
						procNameToken := Unparsed{
							Type:     OtherToken,
							RawValue: fmt.Sprintf(templateRoutineName, strings.Trim(result.QuotedName.Value, "[]")),
						}
						result.Body = append(result.Body, procNameToken)
					}
					firstAs = false
				}
			*/

		default:
			CopyToken(s, &result.Body)
			NextTokenCopyingWhitespace(s, &result.Body)
		}
	}

	// Deterministic ordering of dependencies for stable output/tests.
	sort.Slice(result.DependsOn, func(i, j int) bool {
		return result.DependsOn[i].Value < result.DependsOn[j].Value
	})
	return
}

// --- new file: sqlparser/tsql_document_test.go (continues in next chunk) ---
package sqlparser

import (
	"fmt"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestTSqlDocument(t *testing.T) {
	t.Run("addError", func(t *testing.T) {
		t.Run("adds error with position", func(t *testing.T) {
			doc := &TSqlDocument{}
			s := NewScanner("test.sql", "select")
			s.NextToken()

			doc.addError(s, "test error message")
			require.True(t, doc.HasErrors())
			assert.Equal(t, "test error message", doc.errors[0].Message)
			assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, doc.errors[0].Pos)
		})

		t.Run("accumulates multiple errors", func(t *testing.T) {
			doc := &TSqlDocument{}
			s := NewScanner("test.sql", "abc def")
			s.NextToken()
			doc.addError(s, "error 1")
			s.NextToken()
doc.addError(s, "error 2") + + require.Len(t, doc.errors, 2) + assert.Equal(t, "error 1", doc.errors[0].Message) + assert.Equal(t, "error 2", doc.errors[1].Message) + }) + + t.Run("creates error with token text", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "unexpected_token") + s.NextToken() + + doc.unexpectedTokenError(s) + + require.Len(t, doc.errors, 1) + assert.Equal(t, "Unexpected: unexpected_token", doc.errors[0].Message) + }) + }) + + t.Run("parseTypeExpression", func(t *testing.T) { + t.Run("parses simple type without args", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "int") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "int", typ.BaseType) + assert.Empty(t, typ.Args) + }) + + t.Run("parses type with single arg", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "varchar(50)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "varchar", typ.BaseType) + assert.Equal(t, []string{"50"}, typ.Args) + }) + + t.Run("parses type with multiple args", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "decimal(10, 2)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "decimal", typ.BaseType) + assert.Equal(t, []string{"10", "2"}, typ.Args) + }) + + t.Run("parses type with max", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "nvarchar(max)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "nvarchar", typ.BaseType) + assert.Equal(t, []string{"max"}, typ.Args) + }) + + t.Run("handles invalid arg", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "varchar(invalid)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "varchar", typ.BaseType) + assert.NotEmpty(t, doc.errors) + }) + + t.Run("panics if not on identifier", func(t *testing.T) { + doc := &TSqlDocument{} + s := 
NewScanner("test.sql", "123") + s.NextToken() + + assert.Panics(t, func() { + doc.parseTypeExpression(s) + }) + }) + }) +} + +func TestDocument_parseDeclare(t *testing.T) { + t.Run("parses single enum declaration", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@EnumStatus int = 42") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "@EnumStatus", declares[0].VariableName) + assert.Equal(t, "int", declares[0].Datatype.BaseType) + assert.Equal(t, "42", declares[0].Literal.RawValue) + }) + + t.Run("parses multiple declarations with comma", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@EnumA int = 1, @EnumB int = 2;") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 2) + assert.Equal(t, "@EnumA", declares[0].VariableName) + assert.Equal(t, "@EnumB", declares[1].VariableName) + }) + + t.Run("parses string literal", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@EnumName nvarchar(50) = N'test'") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "N'test'", declares[0].Literal.RawValue) + }) + + t.Run("errors on invalid variable name", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@InvalidName int = 1") + s.NextToken() + + declares := doc.parseDeclare(s) + + // in this case when we detect the missing prefix, + // we add an error and continue parsing the declaration. 
+ // this results with it being added + require.Len(t, declares, 1) + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "@InvalidName") + }) + + t.Run("errors on missing type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@EnumTest = 42") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 0) + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "type declared explicitly") + }) + + t.Run("errors on missing assignment", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@EnumTest int") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 0) + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "needs to be assigned") + }) + + t.Run("accepts @Global prefix", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@GlobalSetting int = 100") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "@GlobalSetting", declares[0].VariableName) + assert.Empty(t, doc.errors) + }) + + t.Run("accepts @Const prefix", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@ConstValue int = 200") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "@ConstValue", declares[0].VariableName) + assert.Empty(t, doc.errors) + }) +} + +func TestDocument_parseBatchSeparator(t *testing.T) { + t.Run("parses valid go separator", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "go\n") + s.NextToken() + + doc.parseBatchSeparator(s) + + assert.Empty(t, doc.errors) + }) + + t.Run("errors on malformed separator", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "go -- comment") + tt := s.NextToken() + fmt.Printf("%#v %#v\n", s, tt) + doc.parseBatchSeparator(s) + fmt.Printf("%#v\n", s) + + assert.NotEmpty(t, doc.errors) + 
assert.Contains(t, doc.errors[0].Message, "should be alone") + }) +} + +func TestDocument_parseCreate(t *testing.T) { + t.Run("parses simple procedure", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure [code].TestProc as begin end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "procedure", create.CreateType) + assert.Equal(t, "[TestProc]", create.QuotedName.Value) + assert.NotEmpty(t, create.Body) + }) + + t.Run("parses function", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create function [code].TestFunc() returns int as begin return 1 end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + assert.Equal(t, "[TestFunc]", create.QuotedName.Value) + }) + + t.Run("parses type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create type [code].TestType as table (id int)") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "type", create.CreateType) + assert.Equal(t, "[TestType]", create.QuotedName.Value) + }) + + t.Run("tracks dependencies", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure [code].Proc1 as begin select * from [code].Table1 join [code].Table2 end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + require.Len(t, create.DependsOn, 2) + assert.Equal(t, "[Table1]", create.DependsOn[0].Value) + assert.Equal(t, "[Table2]", create.DependsOn[1].Value) + }) + + t.Run("deduplicates dependencies", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure [code].Proc1 as begin select * from [code].Table1; select * from [code].Table1 end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + 
require.Len(t, create.DependsOn, 1) + assert.Equal(t, "[Table1]", create.DependsOn[0].Value) + }) + + t.Run("errors on unsupported create type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create table [code].TestTable (id int)") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + doc.parseCreate(s, 0) + + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "only supports creating procedures") + }) + + t.Run("errors on multiple procedures in batch", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure [code].Proc2 as begin end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + doc.parseCreate(s, 1) + + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "must be alone in a batch") + }) + + t.Run("errors on missing code schema", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure dbo.TestProc as begin end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + doc.parseCreate(s, 0) + + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "must be followed by [code]") + }) + + t.Run("allows create index inside procedure", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure [code].Proc as begin create index IX_Test on #temp(id) end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "procedure", create.CreateType) + assert.Empty(t, doc.errors) + }) + + t.Run("stops at batch separator", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure [code].Proc as begin end\ngo") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "[Proc]", create.QuotedName.Value) + assert.Equal(t, BatchSeparatorToken, s.TokenType()) + }) + + t.Run("panics if not on create token", func(t *testing.T) { + doc 
:= &TSqlDocument{} + s := NewScanner("test.sql", "procedure") + s.NextToken() + + assert.Panics(t, func() { + doc.parseCreate(s, 0) + }) + }) +} + +func TestNextTokenCopyingWhitespace(t *testing.T) { + t.Run("copies whitespace tokens", func(t *testing.T) { + s := NewScanner("test.sql", " \n\t token") + var target []Unparsed + + NextTokenCopyingWhitespace(s, &target) + + assert.NotEmpty(t, target) + assert.Equal(t, UnquotedIdentifierToken, s.TokenType()) + }) + + t.Run("copies comments", func(t *testing.T) { + s := NewScanner("test.sql", "/* comment */ -- line\ntoken") + var target []Unparsed + + NextTokenCopyingWhitespace(s, &target) + + assert.True(t, len(target) >= 2) + assert.Equal(t, UnquotedIdentifierToken, s.TokenType()) + }) + + t.Run("stops at EOF", func(t *testing.T) { + s := NewScanner("test.sql", " ") + var target []Unparsed + + NextTokenCopyingWhitespace(s, &target) + + assert.Equal(t, EOFToken, s.TokenType()) + }) +} + +func TestCreateUnparsed(t *testing.T) { + t.Run("creates unparsed from scanner", func(t *testing.T) { + s := NewScanner("test.sql", "select") + s.NextToken() + + unparsed := CreateUnparsed(s) + + assert.Equal(t, ReservedWordToken, unparsed.Type) + assert.Equal(t, "select", unparsed.RawValue) + assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, unparsed.Start) + }) +} diff --git a/sqlparser/unparsed.go b/sqlparser/unparsed.go new file mode 100644 index 0000000..45a85bf --- /dev/null +++ b/sqlparser/unparsed.go @@ -0,0 +1,25 @@ +package sqlparser + +type Unparsed struct { + Type TokenType + Start, Stop Pos + RawValue string +} + +func CreateUnparsed(s *Scanner) Unparsed { + return Unparsed{ + Type: s.TokenType(), + Start: s.Start(), + Stop: s.Stop(), + RawValue: s.Token(), + } +} + +func (u Unparsed) WithoutPos() Unparsed { + return Unparsed{ + Type: u.Type, + Start: Pos{}, + Stop: Pos{}, + RawValue: u.RawValue, + } +} diff --git a/sqltest/fixture.go b/sqltest/fixture.go index 0059908..b384be2 100644 --- a/sqltest/fixture.go +++ 
b/sqltest/fixture.go @@ -4,15 +4,30 @@ import ( "context" "database/sql" "fmt" - mssql "github.com/denisenkom/go-mssqldb" - "github.com/denisenkom/go-mssqldb/msdsn" - "github.com/gofrs/uuid" - "io/ioutil" "os" "strings" + "testing" "time" + + "github.com/gofrs/uuid" + _ "github.com/jackc/pgx/v5" + _ "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" + "github.com/microsoft/go-mssqldb/msdsn" +) + +type SqlDriverType int + +const ( + SqlDriverMssql SqlDriverType = iota + SqlDriverPgx ) +var sqlDrivers = map[SqlDriverType]string{ + SqlDriverMssql: "sqlserver", + SqlDriverPgx: "pgx", +} + type StdoutLogger struct { } @@ -30,6 +45,26 @@ type Fixture struct { DB *sql.DB DBName string adminDB *sql.DB + Driver SqlDriverType +} + +func (f *Fixture) IsSqlServer() bool { + return f.Driver == SqlDriverMssql +} + +func (f *Fixture) IsPostgresql() bool { + return f.Driver == SqlDriverPgx +} + +// SQL specific quoting syntax +func (f *Fixture) Quote(value string) string { + if f.IsSqlServer() { + return fmt.Sprintf("[%s]", value) + } + if f.IsPostgresql() { + return fmt.Sprintf(`"%s"`, value) + } + return value } func NewFixture() *Fixture { @@ -39,44 +74,91 @@ func NewFixture() *Fixture { defer cancel() dsn := os.Getenv("SQLSERVER_DSN") - if dsn == "" { + if len(dsn) == 0 { panic("Must set SQLSERVER_DSN to run tests") } - dsn = dsn + "&log=3" - mssql.SetLogger(StdoutLogger{}) + if strings.Contains(dsn, "sqlserver") { + // set the logging level + // To enable specific logging levels, you sum the values of the desired flags + // 1: Log errors + // 2: Log messages + // 4: Log rows affected + // 8: Trace SQL statements + // 16: Log statement parameters + // 32: Log transaction begin/end + dsn = dsn + "&log=63" + mssql.SetLogger(StdoutLogger{}) + fixture.Driver = SqlDriverMssql + } + if strings.Contains(dsn, "postgresql") { + fixture.Driver = SqlDriverPgx + // https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-CLIENT-MIN-MESSAGES + dsn = 
dsn + "&options=-c%20client_min_messages%3DDEBUG5" + } var err error - - fixture.adminDB, err = sql.Open("sqlserver", dsn) + fixture.adminDB, err = sql.Open(sqlDrivers[fixture.Driver], dsn) if err != nil { panic(err) } - fixture.DBName = strings.ReplaceAll(uuid.Must(uuid.NewV4()).String(), "-", "") - _, err = fixture.adminDB.ExecContext(ctx, fmt.Sprintf(`create database [%s]`, fixture.DBName)) - if err != nil { - panic(err) - } - // These settings are just to get "worst-case" for our tests, since snapshot could interfer - _, err = fixture.adminDB.ExecContext(ctx, fmt.Sprintf(`alter database [%s] set allow_snapshot_isolation on`, fixture.DBName)) - if err != nil { - panic(err) - } - _, err = fixture.adminDB.ExecContext(ctx, fmt.Sprintf(`alter database [%s] set read_committed_snapshot on`, fixture.DBName)) + fixture.DBName = strings.ReplaceAll(uuid.Must(uuid.NewV4()).String(), "-", "") + dbname := fixture.Quote(fixture.DBName) + qs := fmt.Sprintf(`create database %s`, dbname) + _, err = fixture.adminDB.ExecContext(ctx, qs) if err != nil { + fmt.Printf("Failed to create the (%s) database: %s: %e\n", sqlDrivers[fixture.Driver], dbname, err) panic(err) } - pdsn, _, err := msdsn.Parse(dsn) - if err != nil { - panic(err) + if fixture.IsSqlServer() { + // These settings are just to get "worst-case" for our tests, since snapshot could interfer + _, err = fixture.adminDB.ExecContext(ctx, fmt.Sprintf(`alter database %s set allow_snapshot_isolation on`, dbname)) + if err != nil { + panic(err) + } + _, err = fixture.adminDB.ExecContext(ctx, fmt.Sprintf(`alter database %s set read_committed_snapshot on`, dbname)) + if err != nil { + panic(err) + } + + pdsn, err := msdsn.Parse(dsn) + if err != nil { + panic(err) + } + pdsn.Database = fixture.DBName + + fixture.DB, err = sql.Open(sqlDrivers[fixture.Driver], pdsn.URL().String()) + if err != nil { + panic(err) + } } - pdsn.Database = fixture.DBName - fixture.DB, err = sql.Open("sqlserver", pdsn.URL().String()) - if err != nil { - 
panic(err) + if fixture.IsPostgresql() { + // TODO use pgx config parser + fixture.DB, err = sql.Open(sqlDrivers[fixture.Driver], strings.ReplaceAll(dsn, "/master", "/"+fixture.DBName)) + if err != nil { + panic(err) + } + + var user string + err = fixture.DB.QueryRow(`select current_user`).Scan(&user) + if err != nil { + panic(err) + } + _, err = fixture.DB.Exec(fmt.Sprintf(`GRANT ALL ON DATABASE "%s" TO sa;`, fixture.DBName)) + if err != nil { + panic(err) + } + _, err = fixture.DB.Exec(fmt.Sprintf(`GRANT ALL PRIVILEGES ON SCHEMA public TO %s;`, user)) + if err != nil { + panic(err) + } + _, err = fixture.DB.Exec(fmt.Sprintf(`ALTER DATABASE "%s" OWNER TO sa;`, fixture.DBName)) + if err != nil { + panic(err) + } } return &fixture @@ -92,13 +174,16 @@ func (f *Fixture) Teardown() { _ = f.DB.Close() f.DB = nil - _, _ = f.adminDB.ExecContext(ctx, fmt.Sprintf(`drop database [%s]`, f.DBName)) + _, err := f.adminDB.ExecContext(ctx, fmt.Sprintf(`drop database %s`, f.Quote(f.DBName))) + if err != nil { + fmt.Printf("Failed to drop (%s) database %s: %e", sqlDrivers[f.Driver], f.DBName, err) + } _ = f.adminDB.Close() f.adminDB = nil } -func (f *Fixture) RunMigrations() { - migrationSql, err := ioutil.ReadFile("migrations/from0001/0001.changefeed.sql") +func (f *Fixture) RunMigrationFile(filename string) { + migrationSql, err := os.ReadFile(filename) if err != nil { panic(err) } @@ -112,17 +197,22 @@ func (f *Fixture) RunMigrations() { } } -func (f *Fixture) RunMigrationFile(filename string) { - migrationSql, err := ioutil.ReadFile(filename) - if err != nil { - panic(err) - } - parts := strings.Split(string(migrationSql), "\ngo\n") - for _, p := range parts { - _, err = f.DB.Exec(p) - if err != nil { - fmt.Println(p) - panic(err) +func (f *Fixture) RunIfPostgres(t *testing.T, name string, fn func(t *testing.T)) { + t.Run("pgsql", func(t *testing.T) { + if f.IsPostgresql() { + t.Run(name, fn) + } else { + t.Skip() } - } + }) +} + +func (f *Fixture) RunIfMssql(t *testing.T, 
name string, fn func(t *testing.T)) { + t.Run("mssql", func(t *testing.T) { + if f.IsSqlServer() { + t.Run(name, fn) + } else { + t.Skip() + } + }) } diff --git a/sqltest/sql.go b/sqltest/sql.go index 15d7995..4ad2754 100644 --- a/sqltest/sql.go +++ b/sqltest/sql.go @@ -9,7 +9,11 @@ import ( //go:embed *.sql var sqlfs embed.FS +//go:embed *.pgsql +var pgsqlfx embed.FS + var SQL = sqlcode.MustInclude( sqlcode.Options{}, sqlfs, + pgsqlfx, ) diff --git a/sqltest/sqlcode_test.go b/sqltest/sqlcode_test.go index 4352eac..b92cc7f 100644 --- a/sqltest/sqlcode_test.go +++ b/sqltest/sqlcode_test.go @@ -2,30 +2,101 @@ package sqltest import ( "context" + "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func Test_RowsAffected(t *testing.T) { +func Test_Patch(t *testing.T) { fixture := NewFixture() + ctx := context.Background() defer fixture.Teardown() - fixture.RunMigrationFile("../migrations/0001.sqlcode.sql") - ctx := context.Background() + if fixture.IsSqlServer() { + fixture.RunMigrationFile("../migrations/0001.sqlcode.sql") + } + + if fixture.IsPostgresql() { + fixture.RunMigrationFile("../migrations/0001.sqlcode.pgsql") + _, err := fixture.DB.Exec( + fmt.Sprintf(`grant create on database "%s" to "sqlcode-definer-role"`, fixture.DBName)) + require.NoError(t, err) + } require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) - patched := SQL.Patch(`[code].Test`) - - res, err := fixture.DB.ExecContext(ctx, patched) - require.NoError(t, err) - rowsAffected, err := res.RowsAffected() - require.NoError(t, err) - assert.Equal(t, int64(1), rowsAffected) - - schemas := SQL.ListUploaded(ctx, fixture.DB) - require.Len(t, schemas, 1) - require.Equal(t, 6, schemas[0].Objects) - require.Equal(t, "5420c0269aaf", schemas[0].Suffix()) + + fixture.RunIfMssql(t, "returns 1 affected row", func(t *testing.T) { + patched := SQL.CodePatch(fixture.DB, `[code].Test`) + res, err := fixture.DB.ExecContext(ctx, patched) + require.NoError(t, err) + + 
rowsAffected, err := res.RowsAffected() + require.NoError(t, err) + assert.Equal(t, int64(1), rowsAffected) + }) + + // TODO: instrument a test table to perform an update operation + fixture.RunIfPostgres(t, "returns 0 affected rows", func(t *testing.T) { + patched := SQL.CodePatch(fixture.DB, `call [code].Test()`) + res, err := fixture.DB.ExecContext(ctx, patched) + require.NoError(t, err) + + // postgresql perform does not result with affected rows + rowsAffected, err := res.RowsAffected() + require.NoError(t, err) + assert.Equal(t, int64(0), rowsAffected) + }) +} + +func Test_EnsureUploaded(t *testing.T) { + f := NewFixture() + defer f.Teardown() + ctx := context.Background() + + f.RunIfMssql(t, "uploads schema", func(t *testing.T) { + f.RunMigrationFile("../migrations/0001.sqlcode.sql") + require.NoError(t, SQL.EnsureUploaded(ctx, f.DB)) + schemas, err := SQL.ListUploaded(ctx, f.DB) + require.NoError(t, err) + require.Len(t, schemas, 1) + + }) + + f.RunIfPostgres(t, "uploads schema", func(t *testing.T) { + f.RunMigrationFile("../migrations/0001.sqlcode.pgsql") + + _, err := f.DB.Exec( + fmt.Sprintf(`grant create on database "%s" to "sqlcode-definer-role"`, f.DBName)) + require.NoError(t, err) + + require.NoError(t, SQL.EnsureUploaded(ctx, f.DB)) + schemas, err := SQL.ListUploaded(ctx, f.DB) + require.NoError(t, err) + require.Len(t, schemas, 1) + }) +} + +func Test_DropAndUpload(t *testing.T) { + f := NewFixture() + defer f.Teardown() + ctx := context.Background() + + f.RunIfMssql(t, "drop and upload", func(t *testing.T) { + f.RunMigrationFile("../migrations/0001.sqlcode.sql") + require.NoError(t, SQL.EnsureUploaded(ctx, f.DB)) + require.NoError(t, SQL.DropAndUpload(ctx, f.DB)) + }) + + f.RunIfPostgres(t, "drop and upload", func(t *testing.T) { + f.RunMigrationFile("../migrations/0001.sqlcode.pgsql") + + _, err := f.DB.Exec( + fmt.Sprintf(`grant create on database "%s" to "sqlcode-definer-role"`, f.DBName)) + require.NoError(t, err) + + require.NoError(t, 
SQL.EnsureUploaded(ctx, f.DB)) + require.NoError(t, SQL.DropAndUpload(ctx, f.DB)) + }) } diff --git a/sqltest/test.pgsql b/sqltest/test.pgsql new file mode 100644 index 0000000..bff51a5 --- /dev/null +++ b/sqltest/test.pgsql @@ -0,0 +1,7 @@ +create procedure [code].test() +language plpgsql +as $$ +begin + perform 1; +end; +$$; \ No newline at end of file