From fa05d57bc0fd8bdc2f1f0d85db299cd43abb1724 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 6 Nov 2025 18:44:00 +0100 Subject: [PATCH 01/28] Updated testing environment. --- Makefile | 2 ++ docker-compose.test.yml | 27 ++++++++++++++++++--------- dockerfile.test | 6 ++++++ 3 files changed, 26 insertions(+), 9 deletions(-) create mode 100644 Makefile create mode 100644 dockerfile.test diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..1bfea7f --- /dev/null +++ b/Makefile @@ -0,0 +1,2 @@ +test: + docker compose --progress plain -f docker-compose.test.yml run test \ No newline at end of file diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 94da3ec..56f4fc3 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -1,15 +1,24 @@ services: - # - # mssql - # mssql: image: mcr.microsoft.com/mssql/server:latest - - hostname: mssql - container_name: mssql - network_mode: bridge - ports: - - "1433:1433" + networks: + - mssql environment: ACCEPT_EULA: "Y" SA_PASSWORD: VippsPw1 + healthcheck: + test: ["CMD", "/opt/mssql-tools18/bin/sqlcmd", "-C", "-Usa", "-PVippsPw1", "-Q", "select 1"] + interval: 1s + retries: 20 + test: + build: + dockerfile: dockerfile.test + networks: + - mssql + environment: + SQLSERVER_DSN: sqlserver://mssql:1433?database=master&user id=sa&password=VippsPw1 + depends_on: + mssql: + condition: service_healthy +networks: + mssql: diff --git a/dockerfile.test b/dockerfile.test new file mode 100644 index 0000000..9313912 --- /dev/null +++ b/dockerfile.test @@ -0,0 +1,6 @@ +FROM golang:1.23 AS builder +WORKDIR /sqlcode +ENV GODEBUG="x509negativeserial=1" +COPY . . +RUN go mod tidy +CMD ["go", "test", "-v", "$(go list ./... | grep -v './example')"] \ No newline at end of file From 351a6dcb5fe01e5717768fd02c284fda3845b33d Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 6 Nov 2025 19:46:56 +0100 Subject: [PATCH 02/28] Working connection. 
--- Makefile | 5 +- ...mpose.test.yml => docker-compose.mssql.yml | 1 + docker-compose.pgsql.yml | 27 ++++++ dockerfile.test | 2 +- go.mod | 1 + go.sum | 2 + sqlcode.yaml | 6 +- sqltest/fixture.go | 93 ++++++++++++++----- 8 files changed, 108 insertions(+), 29 deletions(-) rename docker-compose.test.yml => docker-compose.mssql.yml (94%) create mode 100644 docker-compose.pgsql.yml diff --git a/Makefile b/Makefile index 1bfea7f..f6980da 100644 --- a/Makefile +++ b/Makefile @@ -1,2 +1,5 @@ test: - docker compose --progress plain -f docker-compose.test.yml run test \ No newline at end of file + docker compose --progress plain -f docker-compose.mssql.yml run test + +test_pgsql: + docker compose --progress plain -f docker-compose.pgsql.yml run test \ No newline at end of file diff --git a/docker-compose.test.yml b/docker-compose.mssql.yml similarity index 94% rename from docker-compose.test.yml rename to docker-compose.mssql.yml index 56f4fc3..84618fc 100644 --- a/docker-compose.test.yml +++ b/docker-compose.mssql.yml @@ -17,6 +17,7 @@ services: - mssql environment: SQLSERVER_DSN: sqlserver://mssql:1433?database=master&user id=sa&password=VippsPw1 + SQLSERVER_DRIVER: sqlserver depends_on: mssql: condition: service_healthy diff --git a/docker-compose.pgsql.yml b/docker-compose.pgsql.yml new file mode 100644 index 0000000..6391e0c --- /dev/null +++ b/docker-compose.pgsql.yml @@ -0,0 +1,27 @@ +services: + postgres: + image: postgres + networks: + - postgres + environment: + POSTGRES_PASSWORD: VippsPw1 + POSTGRES_USER: sa + POSTGRES_DB: master + healthcheck: + test: ["CMD-SHELL", "pg_isready", "-d", "db_prod"] + interval: 1s + retries: 20 + test: + build: + dockerfile: dockerfile.test + networks: + - postgres + environment: + SQLSERVER_DSN: postgresql://sa:VippsPw1@postgres:5432/master?sslmode=disable + SQLSERVER_DRIVER: postgres + GODEBUG: "x509negativeserial=1" + depends_on: + postgres: + condition: service_healthy +networks: + postgres: diff --git a/dockerfile.test 
b/dockerfile.test index 9313912..f4a199f 100644 --- a/dockerfile.test +++ b/dockerfile.test @@ -1,4 +1,4 @@ -FROM golang:1.23 AS builder +FROM golang:1.25.1 AS builder WORKDIR /sqlcode ENV GODEBUG="x509negativeserial=1" COPY . . diff --git a/go.mod b/go.mod index fa5e2c6..65a2812 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.9 // indirect diff --git a/go.sum b/go.sum index d0ecd83..6433e10 100644 --- a/go.sum +++ b/go.sum @@ -22,6 +22,8 @@ github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= diff --git a/sqlcode.yaml b/sqlcode.yaml index 549c23f..8adefb7 100644 --- a/sqlcode.yaml +++ b/sqlcode.yaml @@ -1,6 +1,8 @@ databases: - localtest: - connection: sqlserver://localhost:1433?database=foo&user id=foouser&password=FooPasswd1 + mssql: + connection: sqlserver://mssql:1433?database=foo&user 
id=foouser&password=FooPasswd1 + pgsql: + connection: postgresql://sa:VippsPw1@postgres:5432/master?sslmode=disable # One option is to list other paths to include ('dependencies') here. diff --git a/sqltest/fixture.go b/sqltest/fixture.go index 0059908..8c86651 100644 --- a/sqltest/fixture.go +++ b/sqltest/fixture.go @@ -3,14 +3,17 @@ package sqltest import ( "context" "database/sql" + "database/sql/driver" "fmt" - mssql "github.com/denisenkom/go-mssqldb" - "github.com/denisenkom/go-mssqldb/msdsn" - "github.com/gofrs/uuid" "io/ioutil" "os" "strings" "time" + + mssql "github.com/denisenkom/go-mssqldb" + "github.com/denisenkom/go-mssqldb/msdsn" + "github.com/gofrs/uuid" + pgsql "github.com/lib/pq" ) type StdoutLogger struct { @@ -30,6 +33,20 @@ type Fixture struct { DB *sql.DB DBName string adminDB *sql.DB + Driver driver.Driver +} + +func (f *Fixture) Quote(value string) string { + var ms mssql.Driver + var pg pgsql.Driver + + if f.Driver == &ms { + return fmt.Sprintf("[%s]", value) + } + if f.Driver == &pg { + return fmt.Sprintf(`"%s"`, value) + } + return value } func NewFixture() *Fixture { @@ -39,44 +56,70 @@ func NewFixture() *Fixture { defer cancel() dsn := os.Getenv("SQLSERVER_DSN") - if dsn == "" { + if len(dsn) == 0 { panic("Must set SQLSERVER_DSN to run tests") } - dsn = dsn + "&log=3" - mssql.SetLogger(StdoutLogger{}) + driver := os.Getenv("SQLSERVER_DRIVER") + if len(driver) == 0 { + panic("Must set SQLSERVER_DRIVER to run tests") + } + + switch driver { + case "sqlserver": + // set the logging level + dsn = dsn + "&log=3" + mssql.SetLogger(StdoutLogger{}) + case "postgres": + break + } var err error - fixture.adminDB, err = sql.Open("sqlserver", dsn) + fixture.adminDB, err = sql.Open(driver, dsn) if err != nil { panic(err) } - fixture.DBName = strings.ReplaceAll(uuid.Must(uuid.NewV4()).String(), "-", "") + // store a reference to the type of sql driver + fixture.Driver = fixture.adminDB.Driver() - _, err = fixture.adminDB.ExecContext(ctx, 
fmt.Sprintf(`create database [%s]`, fixture.DBName)) - if err != nil { - panic(err) - } - // These settings are just to get "worst-case" for our tests, since snapshot could interfer - _, err = fixture.adminDB.ExecContext(ctx, fmt.Sprintf(`alter database [%s] set allow_snapshot_isolation on`, fixture.DBName)) - if err != nil { - panic(err) - } - _, err = fixture.adminDB.ExecContext(ctx, fmt.Sprintf(`alter database [%s] set read_committed_snapshot on`, fixture.DBName)) + fixture.DBName = strings.ReplaceAll(uuid.Must(uuid.NewV4()).String(), "-", "") + dbname := fixture.Quote(fixture.DBName) + _, err = fixture.adminDB.ExecContext(ctx, fmt.Sprintf(`create database %s`, dbname)) if err != nil { + fmt.Printf("Failed to create the database: %s for the %s driver\n", dbname, driver) panic(err) } - pdsn, _, err := msdsn.Parse(dsn) - if err != nil { - panic(err) + if driver == "sqlserver" { + // These settings are just to get "worst-case" for our tests, since snapshot could interfer + _, err = fixture.adminDB.ExecContext(ctx, fmt.Sprintf(`alter database %s set allow_snapshot_isolation on`, dbname)) + if err != nil { + panic(err) + } + _, err = fixture.adminDB.ExecContext(ctx, fmt.Sprintf(`alter database %s set read_committed_snapshot on`, dbname)) + if err != nil { + panic(err) + } + + pdsn, _, err := msdsn.Parse(dsn) + if err != nil { + panic(err) + } + pdsn.Database = fixture.DBName + + fixture.DB, err = sql.Open(driver, pdsn.URL().String()) + if err != nil { + panic(err) + } } - pdsn.Database = fixture.DBName - fixture.DB, err = sql.Open("sqlserver", pdsn.URL().String()) - if err != nil { - panic(err) + if driver == "postgres" { + // TODO + fixture.DB, err = sql.Open(driver, strings.ReplaceAll(dsn, "/master", "/"+fixture.DBName)) + if err != nil { + panic(err) + } } return &fixture From b1b982d2fbb183e1d8aa637d77585f20efbba71a Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Wed, 3 Dec 2025 18:49:47 +0100 Subject: [PATCH 03/28] Updated fixture to use pgx. 
Simplified how we introspect the DSN to determine the driver type. Allow for future expansion/testing with other drivers. Preparing to write sqlcode migration for Postgres. --- Makefile | 5 +- deployable.go | 1 + go.mod | 5 ++ go.sum | 12 +++++ migrations/0003.sqlcode.pgsql | 0 sqltest/fixture.go | 93 ++++++++++++++++++----------------- sqltest/sqlcode_test.go | 24 +++++++++ 7 files changed, 94 insertions(+), 46 deletions(-) create mode 100644 migrations/0003.sqlcode.pgsql diff --git a/Makefile b/Makefile index f6980da..fbcbdeb 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,7 @@ -test: +test: test_mssql test_pgsql + + +test_mssql: docker compose --progress plain -f docker-compose.mssql.yml run test test_pgsql: docker compose --progress plain -f docker-compose.pgsql.yml run test \ No newline at end of file diff --git a/deployable.go b/deployable.go index 7e2b178..dcf2726 100644 --- a/deployable.go +++ b/deployable.go @@ -158,6 +158,7 @@ select @retcode; } defer func() { + // TODO: This returns an error if the lock is already released _, _ = dbc.ExecContext(ctx, `sp_releaseapplock`, sql.Named("Resource", lockResourceName), sql.Named("LockOwner", "Session"), diff --git a/go.mod b/go.mod index 65a2812..b482681 100644 --- a/go.mod +++ b/go.mod @@ -22,11 +22,16 @@ require ( github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.7.6 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/lib/pq v1.10.9 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.9 // indirect golang.org/x/crypto v0.43.0 // indirect + golang.org/x/sync v0.17.0 // indirect golang.org/x/sys v0.37.0 // indirect golang.org/x/text v0.30.0 // indirect ) diff --git a/go.sum b/go.sum index 
6433e10..17eb045 100644 --- a/go.sum +++ b/go.sum @@ -22,6 +22,14 @@ github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= +github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= @@ -40,6 +48,7 @@ github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4 github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= @@ -53,6 +62,8 @@ golang.org/x/net v0.0.0-20210610132358-84b48f89b13b/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -71,6 +82,7 @@ golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/migrations/0003.sqlcode.pgsql b/migrations/0003.sqlcode.pgsql new file mode 100644 index 0000000..e69de29 diff --git a/sqltest/fixture.go b/sqltest/fixture.go index 8c86651..b38ac06 100644 
--- a/sqltest/fixture.go +++ b/sqltest/fixture.go @@ -3,9 +3,7 @@ package sqltest import ( "context" "database/sql" - "database/sql/driver" "fmt" - "io/ioutil" "os" "strings" "time" @@ -13,9 +11,22 @@ import ( mssql "github.com/denisenkom/go-mssqldb" "github.com/denisenkom/go-mssqldb/msdsn" "github.com/gofrs/uuid" - pgsql "github.com/lib/pq" + _ "github.com/jackc/pgx/v5" + _ "github.com/jackc/pgx/v5/stdlib" ) +type SqlDriverType int + +const ( + SqlDriverDenisen SqlDriverType = iota + SqlDriverPgx +) + +var sqlDrivers = map[SqlDriverType]string{ + SqlDriverDenisen: "sqlserver", + SqlDriverPgx: "pgx", +} + type StdoutLogger struct { } @@ -33,17 +44,23 @@ type Fixture struct { DB *sql.DB DBName string adminDB *sql.DB - Driver driver.Driver + Driver SqlDriverType } -func (f *Fixture) Quote(value string) string { - var ms mssql.Driver - var pg pgsql.Driver +func (f *Fixture) IsSqlServer() bool { + return f.Driver == SqlDriverDenisen +} - if f.Driver == &ms { +func (f *Fixture) IsPostgresql() bool { + return f.Driver == SqlDriverPgx +} + +// SQL specific quoting syntax +func (f *Fixture) Quote(value string) string { + if f.IsSqlServer() { return fmt.Sprintf("[%s]", value) } - if f.Driver == &pg { + if f.IsPostgresql() { return fmt.Sprintf(`"%s"`, value) } return value @@ -60,38 +77,39 @@ func NewFixture() *Fixture { panic("Must set SQLSERVER_DSN to run tests") } - driver := os.Getenv("SQLSERVER_DRIVER") - if len(driver) == 0 { - panic("Must set SQLSERVER_DRIVER to run tests") - } - - switch driver { - case "sqlserver": + if strings.Contains(dsn, "sqlserver") { // set the logging level - dsn = dsn + "&log=3" + // To enable specific logging levels, you sum the values of the desired flags + // 1: Log errors + // 2: Log messages + // 4: Log rows affected + // 8: Trace SQL statements + // 16: Log statement parameters + // 32: Log transaction begin/end + dsn = dsn + "&log=63" mssql.SetLogger(StdoutLogger{}) - case "postgres": - break + fixture.Driver = SqlDriverDenisen + } + 
if strings.Contains(dsn, "postgresql") { + fixture.Driver = SqlDriverPgx } var err error - - fixture.adminDB, err = sql.Open(driver, dsn) + fixture.adminDB, err = sql.Open(sqlDrivers[fixture.Driver], dsn) if err != nil { panic(err) } - // store a reference to the type of sql driver - fixture.Driver = fixture.adminDB.Driver() fixture.DBName = strings.ReplaceAll(uuid.Must(uuid.NewV4()).String(), "-", "") dbname := fixture.Quote(fixture.DBName) - _, err = fixture.adminDB.ExecContext(ctx, fmt.Sprintf(`create database %s`, dbname)) + qs := fmt.Sprintf(`create database %s`, dbname) + _, err = fixture.adminDB.ExecContext(ctx, qs) if err != nil { - fmt.Printf("Failed to create the database: %s for the %s driver\n", dbname, driver) + fmt.Printf("Failed to create the (%s) database: %s\n", sqlDrivers[fixture.Driver], dbname) panic(err) } - if driver == "sqlserver" { + if fixture.IsSqlServer() { // These settings are just to get "worst-case" for our tests, since snapshot could interfer _, err = fixture.adminDB.ExecContext(ctx, fmt.Sprintf(`alter database %s set allow_snapshot_isolation on`, dbname)) if err != nil { @@ -108,15 +126,15 @@ func NewFixture() *Fixture { } pdsn.Database = fixture.DBName - fixture.DB, err = sql.Open(driver, pdsn.URL().String()) + fixture.DB, err = sql.Open(sqlDrivers[fixture.Driver], pdsn.URL().String()) if err != nil { panic(err) } } - if driver == "postgres" { + if fixture.IsPostgresql() { // TODO - fixture.DB, err = sql.Open(driver, strings.ReplaceAll(dsn, "/master", "/"+fixture.DBName)) + fixture.DB, err = sql.Open(sqlDrivers[fixture.Driver], strings.ReplaceAll(dsn, "/master", "/"+fixture.DBName)) if err != nil { panic(err) } @@ -140,23 +158,8 @@ func (f *Fixture) Teardown() { f.adminDB = nil } -func (f *Fixture) RunMigrations() { - migrationSql, err := ioutil.ReadFile("migrations/from0001/0001.changefeed.sql") - if err != nil { - panic(err) - } - parts := strings.Split(string(migrationSql), "\ngo\n") - for _, p := range parts { - _, err = 
f.DB.Exec(p) - if err != nil { - fmt.Println(p) - panic(err) - } - } -} - func (f *Fixture) RunMigrationFile(filename string) { - migrationSql, err := ioutil.ReadFile(filename) + migrationSql, err := os.ReadFile(filename) if err != nil { panic(err) } diff --git a/sqltest/sqlcode_test.go b/sqltest/sqlcode_test.go index 4352eac..12e06f1 100644 --- a/sqltest/sqlcode_test.go +++ b/sqltest/sqlcode_test.go @@ -11,6 +11,7 @@ import ( func Test_RowsAffected(t *testing.T) { fixture := NewFixture() defer fixture.Teardown() + // if sql else pgsql fixture.RunMigrationFile("../migrations/0001.sqlcode.sql") ctx := context.Background() @@ -29,3 +30,26 @@ func Test_RowsAffected(t *testing.T) { require.Equal(t, 6, schemas[0].Objects) require.Equal(t, "5420c0269aaf", schemas[0].Suffix()) } + +func Test_EnsureUploaded(t *testing.T) { + fixture := NewFixture() + defer fixture.Teardown() + + t.Run("mssql", func(t *testing.T) { + if !fixture.IsSqlServer() { + t.Skip() + } + fixture.RunMigrationFile("../migrations/0001.sqlcode.sql") + + ctx := context.Background() + require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) + }) + + t.Run("pgsql", func(t *testing.T) { + if !fixture.IsPostgresql() { + t.Skip() + } + + fixture.RunMigrationFile("../migrations/0001.sqlcode.sql") + }) +} From 2d9462f4cf5fd1899a631a6c47c6c78329c8c48a Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Wed, 3 Dec 2025 19:50:31 +0100 Subject: [PATCH 04/28] Wrote initial sqlcode migration for postgres. 
--- docker-compose.pgsql.yml | 1 - migrations/0003.sqlcode.pgsql | 166 ++++++++++++++++++++++++++++++++++ sqltest/fixture.go | 11 ++- sqltest/sqlcode_test.go | 37 +++++--- 4 files changed, 196 insertions(+), 19 deletions(-) diff --git a/docker-compose.pgsql.yml b/docker-compose.pgsql.yml index 6391e0c..bdb0f0d 100644 --- a/docker-compose.pgsql.yml +++ b/docker-compose.pgsql.yml @@ -18,7 +18,6 @@ services: - postgres environment: SQLSERVER_DSN: postgresql://sa:VippsPw1@postgres:5432/master?sslmode=disable - SQLSERVER_DRIVER: postgres GODEBUG: "x509negativeserial=1" depends_on: postgres: diff --git a/migrations/0003.sqlcode.pgsql b/migrations/0003.sqlcode.pgsql index e69de29..b26b8ed 100644 --- a/migrations/0003.sqlcode.pgsql +++ b/migrations/0003.sqlcode.pgsql @@ -0,0 +1,166 @@ +-- ====================================================================== +-- Create users and roles +-- ====================================================================== +do $$ +begin + -- This role will own the sqlcode schemas, so that created functions etc. + -- are owned by a role without permissions; this means functions/procedures + -- will not get more permissions than the caller already has (unless you use + -- SECURITY DEFINER somewhere). + if not exists ( + select 1 + from pg_roles + where rolname = 'sqlcode-user-with-no-permissions' + ) then + create role "sqlcode-user-with-no-permissions" nologin; + end if; + + -- This role will be granted execute (usage) permissions on all sqlcode schemas; + -- useful e.g. for humans logging in to debug. + if not exists ( + select 1 + from pg_roles + where rolname = 'sqlcode-execute-role' + ) then + create role "sqlcode-execute-role"; + end if; + + -- Role for calling CreateCodeSchema / DropCodeSchema; the role will also be granted + -- control over all schemas created this way. 
+ if not exists ( + select 1 + from pg_roles + where rolname = 'sqlcode-deploy-role' + ) then + create role "sqlcode-deploy-role"; + end if; + + -- Make a role that *only* has this deploy role. During deploys we SET ROLE to this + -- so that we can more safely deploy code with restricted permissions. + if not exists ( + select 1 + from pg_roles + where rolname = 'sqlcode-deploy-sandbox-user' + ) then + create role "sqlcode-deploy-sandbox-user" nologin; + end if; +end; +$$; + +-- ====================================================================== +-- grant permissions +-- ====================================================================== + +do $$ +begin + -- grant "sqlcode-deploy-role" to "sqlcode-deploy-sandbox-user" + if not exists ( + select 1 + from pg_auth_members m + join pg_roles r_role on r_role.oid = m.roleid + join pg_roles r_member on r_member.oid = m.member + where r_role.rolname = 'sqlcode-deploy-role' + and r_member.rolname = 'sqlcode-deploy-sandbox-user' + ) then + grant "sqlcode-deploy-role" to "sqlcode-deploy-sandbox-user"; + end if; + +end; +$$; + +-- ====================================================================== +-- create schema +-- ====================================================================== + +-- Base schema to hold the procedures etc. 
+do $$ +begin + if not exists ( + select 1 from pg_namespace where nspname = 'sqlcode' + ) then + create schema sqlcode; + end if; +end; +$$; + +-- ====================================================================== +-- create procedures +-- ====================================================================== + +create or replace procedure sqlcode.createcodeschema(schemasuffix varchar) +language plpgsql +security definer +as $$ +declare + schemaname text := format('code@%s', schemasuffix); +begin + -- create the schema owned by "sqlcode-user-with-no-permissions" + execute format( + 'create schema %I authorization %I', + schemaname, + 'sqlcode-user-with-no-permissions' + ); + + -- grant schema privileges + execute format( + 'grant usage on schema %I to %I', + schemaname, + 'sqlcode-execute-role' + ); + + execute format( + 'grant usage, create on schema %I to %I', + schemaname, + 'sqlcode-deploy-role' + ); + +exception + when others then + raise; +end; +$$; + +-- ====================================================================== +-- procedure: sqlcode.dropcodeschema +-- ====================================================================== + +create or replace procedure sqlcode.dropcodeschema(schemasuffix varchar) +language plpgsql +security definer +as $$ +declare + schemaname text := format('code@%s', schemasuffix); + schema_exists boolean; +begin + -- check schema existence + select exists ( + select 1 + from pg_namespace + where nspname = schemaname + ) into schema_exists; + + if not schema_exists then + raise exception 'schema "%" not found', schemaname; + end if; + + -- drop the schema and all objects within it + execute format('drop schema %I cascade', schemaname); + +exception + when others then + raise; +end; +$$; + +-- ====================================================================== +-- privileges on the procedures and base schema +-- ====================================================================== + +grant execute on procedure 
sqlcode.createcodeschema(varchar) + to "sqlcode-deploy-role"; + +grant execute on procedure sqlcode.dropcodeschema(varchar) + to "sqlcode-deploy-role"; + +grant usage, create on schema sqlcode + to "sqlcode-deploy-role"; diff --git a/sqltest/fixture.go b/sqltest/fixture.go index b38ac06..82f395d 100644 --- a/sqltest/fixture.go +++ b/sqltest/fixture.go @@ -92,6 +92,8 @@ func NewFixture() *Fixture { } if strings.Contains(dsn, "postgresql") { fixture.Driver = SqlDriverPgx + // https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-CLIENT-MIN-MESSAGES + dsn = dsn + "&options=-c%20client_min_messages%3DDEBUG5" } var err error @@ -105,7 +107,7 @@ func NewFixture() *Fixture { qs := fmt.Sprintf(`create database %s`, dbname) _, err = fixture.adminDB.ExecContext(ctx, qs) if err != nil { - fmt.Printf("Failed to create the (%s) database: %s\n", sqlDrivers[fixture.Driver], dbname) + fmt.Printf("Failed to create the (%s) database: %s: %e\n", sqlDrivers[fixture.Driver], dbname, err) panic(err) } @@ -133,7 +135,7 @@ func NewFixture() *Fixture { } if fixture.IsPostgresql() { - // TODO + // TODO use pgx config parser fixture.DB, err = sql.Open(sqlDrivers[fixture.Driver], strings.ReplaceAll(dsn, "/master", "/"+fixture.DBName)) if err != nil { panic(err) @@ -153,7 +155,10 @@ func (f *Fixture) Teardown() { _ = f.DB.Close() f.DB = nil - _, _ = f.adminDB.ExecContext(ctx, fmt.Sprintf(`drop database [%s]`, f.DBName)) + _, err := f.adminDB.ExecContext(ctx, fmt.Sprintf(`drop database %s`, f.Quote(f.DBName))) + if err != nil { + fmt.Printf("Failed to drop (%s) database %s: %e", sqlDrivers[f.Driver], f.DBName, err) + } _ = f.adminDB.Close() f.adminDB = nil } diff --git a/sqltest/sqlcode_test.go b/sqltest/sqlcode_test.go index 12e06f1..9de22a3 100644 --- a/sqltest/sqlcode_test.go +++ b/sqltest/sqlcode_test.go @@ -11,24 +11,31 @@ import ( func Test_RowsAffected(t *testing.T) { fixture := NewFixture() defer fixture.Teardown() - // if sql else pgsql - 
fixture.RunMigrationFile("../migrations/0001.sqlcode.sql") + t.Run("mssql", func(t *testing.T) { + if !fixture.IsSqlServer() { + t.Skip() + } + + fixture.RunMigrationFile("../migrations/0001.sqlcode.sql") - ctx := context.Background() + ctx := context.Background() + + require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) + patched := SQL.Patch(`[code].Test`) - require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) - patched := SQL.Patch(`[code].Test`) + res, err := fixture.DB.ExecContext(ctx, patched) + require.NoError(t, err) + rowsAffected, err := res.RowsAffected() + require.NoError(t, err) + assert.Equal(t, int64(1), rowsAffected) - res, err := fixture.DB.ExecContext(ctx, patched) - require.NoError(t, err) - rowsAffected, err := res.RowsAffected() - require.NoError(t, err) - assert.Equal(t, int64(1), rowsAffected) + schemas := SQL.ListUploaded(ctx, fixture.DB) + require.Len(t, schemas, 1) + require.Equal(t, 6, schemas[0].Objects) + require.Equal(t, "5420c0269aaf", schemas[0].Suffix()) + + }) - schemas := SQL.ListUploaded(ctx, fixture.DB) - require.Len(t, schemas, 1) - require.Equal(t, 6, schemas[0].Objects) - require.Equal(t, "5420c0269aaf", schemas[0].Suffix()) } func Test_EnsureUploaded(t *testing.T) { @@ -50,6 +57,6 @@ func Test_EnsureUploaded(t *testing.T) { t.Skip() } - fixture.RunMigrationFile("../migrations/0001.sqlcode.sql") + fixture.RunMigrationFile("../migrations/0003.sqlcode.pgsql") }) } From 66431588fd171b1c6c0a8f7d43bf001021d4160e Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Wed, 3 Dec 2025 20:06:03 +0100 Subject: [PATCH 05/28] Add security definer role. 
--- migrations/0003.sqlcode.pgsql | 62 +++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 24 deletions(-) diff --git a/migrations/0003.sqlcode.pgsql b/migrations/0003.sqlcode.pgsql index b26b8ed..a338627 100644 --- a/migrations/0003.sqlcode.pgsql +++ b/migrations/0003.sqlcode.pgsql @@ -1,12 +1,9 @@ -- ====================================================================== --- Create users and roles +-- create users and roles -- ====================================================================== do $$ begin - -- This role will own the sqlcode schemas, so that created functions etc. - -- are owned by a role without permissions; this means functions/procedures - -- will not get more permissions than the caller already has (unless you use - -- SECURITY DEFINER somewhere). + -- role that will own the sqlcode schemas (actual code schemas), with no login if not exists ( select 1 from pg_roles @@ -15,8 +12,16 @@ begin create role "sqlcode-user-with-no-permissions" nologin; end if; - -- This role will be granted execute (usage) permissions on all sqlcode schemas; - -- useful e.g. for humans logging in to debug. + -- role that owns the management schema/procedures (security definer) + if not exists ( + select 1 + from pg_roles + where rolname = 'sqlcode-definer-role' + ) then + create role "sqlcode-definer-role" nologin; + end if; + + -- role that gets execute / usage on code schemas (for humans debugging etc.) if not exists ( select 1 from pg_roles @@ -25,8 +30,8 @@ begin create role "sqlcode-execute-role"; end if; - -- Role for calling CreateCodeSchema / DropCodeSchema; the role will also be granted - -- control over all schemas created this way. + -- role for calling createcodeschema / dropcodeschema; + -- this role does not own the procedures, it only calls them. if not exists ( select 1 from pg_roles @@ -35,8 +40,7 @@ begin create role "sqlcode-deploy-role"; end if; - -- Make a role that *only* has this deploy role. 
During deploys we SET ROLE to this - -- so that we can more safely deploy code with restricted permissions. + -- sandbox role used during deploys, which only has sqlcode-deploy-role if not exists ( select 1 from pg_roles @@ -48,7 +52,7 @@ end; $$; -- ====================================================================== --- grant permissions +-- grant permissions / role memberships -- ====================================================================== do $$ @@ -64,27 +68,25 @@ begin ) then grant "sqlcode-deploy-role" to "sqlcode-deploy-sandbox-user"; end if; - end; $$; -- ====================================================================== --- create schema +-- create schema for management code (owner = definer role) -- ====================================================================== --- Base schema to hold the procedures etc. do $$ begin if not exists ( select 1 from pg_namespace where nspname = 'sqlcode' ) then - create schema sqlcode; + create schema sqlcode authorization "sqlcode-definer-role"; end if; end; $$; -- ====================================================================== --- create procedures +-- create procedures (security definer) -- ====================================================================== create or replace procedure sqlcode.createcodeschema(schemasuffix varchar) @@ -94,6 +96,9 @@ as $$ declare schemaname text := format('code@%s', schemasuffix); begin + -- harden search_path for security-definer (optional but recommended) + perform set_config('search_path', 'pg_catalog', true); + -- create the schema owned by "sqlcode-user-with-no-permissions" execute format( 'create schema %I authorization %I', @@ -120,18 +125,17 @@ exception end; $$; --- ====================================================================== --- procedure: sqlcode.dropcodeschema --- ====================================================================== - create or replace procedure sqlcode.dropcodeschema(schemasuffix varchar) language plpgsql security 
definer as $$ declare - schemaname text := format('code@%s', schemasuffix); - schema_exists boolean; + schemaname text := format('code@%s', schemasuffix); + schema_exists boolean; begin + -- harden search_path for security-definer (optional but recommended) + perform set_config('search_path', 'pg_catalog', true); + -- check schema existence select exists ( select 1 @@ -152,15 +156,25 @@ exception end; $$; +-- ensure procedures are owned by the definer role +alter procedure sqlcode.createcodeschema(varchar) + owner to "sqlcode-definer-role"; + +alter procedure sqlcode.dropcodeschema(varchar) + owner to "sqlcode-definer-role"; + -- ====================================================================== -- privileges on the procedures and base schema -- ====================================================================== +-- allow deploy role to call the management procedures grant execute on procedure sqlcode.createcodeschema(varchar) to "sqlcode-deploy-role"; grant execute on procedure sqlcode.dropcodeschema(varchar) to "sqlcode-deploy-role"; -grant usage, create on schema sqlcode +-- usually deploy role does not need create in the sqlcode management schema +-- (the procedures handle creation in separate "code@..." schemas) +grant usage on schema sqlcode to "sqlcode-deploy-role"; From fd8c76f40873e91930dc57d3fe083182aac06b50 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Wed, 3 Dec 2025 21:51:32 +0100 Subject: [PATCH 06/28] [wip] working through changes for EnsureUploaded to support postgresql. 
--- dbintf.go | 2 + dbops.go | 16 ++++- deployable.go | 106 ++++++++++++++++++++++++---------- docker-compose.pgsql.yml | 1 + migrations/0003.sqlcode.pgsql | 46 +++++++++++++++ sqltest/sqlcode_test.go | 21 +++++++ 6 files changed, 161 insertions(+), 31 deletions(-) diff --git a/dbintf.go b/dbintf.go index 8257e11..1495942 100644 --- a/dbintf.go +++ b/dbintf.go @@ -3,6 +3,7 @@ package sqlcode import ( "context" "database/sql" + "database/sql/driver" ) type DB interface { @@ -11,6 +12,7 @@ type DB interface { QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row Conn(ctx context.Context) (*sql.Conn, error) BeginTx(ctx context.Context, txOptions *sql.TxOptions) (*sql.Tx, error) + Driver() driver.Driver } var _ DB = &sql.DB{} diff --git a/dbops.go b/dbops.go index 05e5a88..0293a2a 100644 --- a/dbops.go +++ b/dbops.go @@ -3,11 +3,25 @@ package sqlcode import ( "context" "database/sql" + + mssql "github.com/denisenkom/go-mssqldb" + "github.com/jackc/pgx/v5/stdlib" ) func Exists(ctx context.Context, dbc DB, schemasuffix string) (bool, error) { var schemaID int - err := dbc.QueryRowContext(ctx, `select isnull(schema_id(@p1), 0)`, SchemaName(schemasuffix)).Scan(&schemaID) + + driver := dbc.Driver() + var qs string + + if _, ok := driver.(*mssql.Driver); ok { + qs = `select isnull(schema_id(@p1), 0)` + } + if _, ok := driver.(*stdlib.Driver); ok { + qs = `select coalesce((select oid from pg_namespace where nspname = $1),0)` + } + + err := dbc.QueryRowContext(ctx, qs, SchemaName(schemasuffix)).Scan(&schemaID) if err != nil { return false, err } diff --git a/deployable.go b/deployable.go index dcf2726..c1b38fe 100644 --- a/deployable.go +++ b/deployable.go @@ -11,6 +11,9 @@ import ( "time" mssql "github.com/denisenkom/go-mssqldb" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" + pgxstdlib "github.com/jackc/pgx/v5/stdlib" "github.com/vippsas/sqlcode/sqlparser" ) @@ -77,21 +80,22 @@ func impersonate(ctx context.Context, dbc DB, username 
string, f func(conn *sql. // Upload will create and upload the schema; resulting in an error // if the schema already exists func (d *Deployable) Upload(ctx context.Context, dbc DB) error { - // First, impersonate a user with minimal privileges to get at least - // some level of sandboxing so that migration scripts can't do anything - // the caller didn't expect them to. - return impersonate(ctx, dbc, "sqlcode-deploy-sandbox-user", func(conn *sql.Conn) error { + driver := dbc.Driver() + qs := make(map[string][]interface{}, 1) + + var uploadFunc = func(conn *sql.Conn) error { tx, err := conn.BeginTx(ctx, nil) if err != nil { return err } - _, err = tx.ExecContext(ctx, `sqlcode.CreateCodeSchema`, - sql.Named("schemasuffix", d.SchemaSuffix), - ) - if err != nil { - _ = tx.Rollback() - return err + for q, args := range qs { + _, err = tx.ExecContext(ctx, q, args...) + + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("failed to execute (%s) with arg(%s) in schema %s: %w", q, args, d.SchemaSuffix, err) + } } preprocessed, err := Preprocess(d.CodeBase, d.SchemaSuffix) @@ -123,8 +127,36 @@ func (d *Deployable) Upload(ctx context.Context, dbc DB) error { return nil - }) + } + + if _, ok := driver.(*mssql.Driver); ok { + // First, impersonate a user with minimal privileges to get at least + // some level of sandboxing so that migration scripts can't do anything + // the caller didn't expect them to. 
+ qs["sqlcode.CreateCodeSchema"] = []interface { + }{ + sql.Named("schemasuffix", d.SchemaSuffix), + } + return impersonate(ctx, dbc, "sqlcode-deploy-sandbox-user", uploadFunc) + } + + if _, ok := driver.(*stdlib.Driver); ok { + qs[`set role "sqlcode-deploy-sandbox-user"`] = nil + qs[`call sqlcode.createcodeschema(@schemasuffix)`] = []interface{}{ + pgx.NamedArgs{"schemasuffix": d.SchemaSuffix}, + } + conn, err := dbc.Conn(ctx) + if err != nil { + return err + } + defer func() { + _ = conn.Close() + }() + return uploadFunc(conn) + } + + return fmt.Errorf("failed to determine sql driver to upload schema: %s", d.SchemaSuffix) } // EnsureUploaded checks that the schema with the suffix already exists, @@ -137,37 +169,51 @@ func (d *Deployable) EnsureUploaded(ctx context.Context, dbc DB) error { return nil } + driver := dbc.Driver() lockResourceName := "sqlcode.EnsureUploaded/" + d.SchemaSuffix + var lockRetCode int + var lockQs string + var unlockQs string + var err error + // When a lock is opened with the Transaction lock owner, // that lock is released when the transaction is committed or rolled back. 
-	var lockRetCode int
-	err := dbc.QueryRowContext(ctx, `
-declare @retcode int;
-exec @retcode = sp_getapplock @Resource = @resource, @LockMode = 'Shared', @LockOwner = 'Session', @LockTimeout = @timeoutMs;
-select @retcode;
-`,
-		sql.Named("resource", lockResourceName),
-		sql.Named("timeoutMs", 20000),
-	).Scan(&lockRetCode)
+	if _, ok := driver.(*pgxstdlib.Driver); ok {
+		lockQs = `select sqlcode.get_applock(@resource, @timeoutMs)`
+		unlockQs = `select sqlcode.release_applock(@resource)`
+
+		err = dbc.QueryRowContext(ctx, lockQs, pgx.NamedArgs{
+			"resource":  lockResourceName,
+			"timeoutMs": 20000,
+		}).Scan(&lockRetCode)
+
+		defer func() {
+			dbc.ExecContext(ctx, unlockQs, pgx.NamedArgs{"resource": lockResourceName})
+		}()
+	}
+
+	if _, ok := driver.(*mssql.Driver); ok {
+		// TODO
+
+		defer func() {
+			// TODO: This returns an error if the lock is already released
+			_, _ = dbc.ExecContext(ctx, unlockQs,
+				sql.Named("Resource", lockResourceName),
+				sql.Named("LockOwner", "Session"),
+			)
+		}()
+	}
+
 	if err != nil {
 		return err
 	}
 	if lockRetCode < 0 {
 		return errors.New("was not able to get lock before timeout")
 	}
-
-	defer func() {
-		// TODO: This returns an error if the lock is already released
-		_, _ = dbc.ExecContext(ctx, `sp_releaseapplock`,
-			sql.Named("Resource", lockResourceName),
-			sql.Named("LockOwner", "Session"),
-		)
-	}()
-
 	exists, err := Exists(ctx, dbc, d.SchemaSuffix)
 	if err != nil {
-		return err
+		return fmt.Errorf("unable to determine if schema %s exists: %w", d.SchemaSuffix, err)
 	}
 
 	if exists {
diff --git a/docker-compose.pgsql.yml b/docker-compose.pgsql.yml
index bdb0f0d..a6d7a0b 100644
--- a/docker-compose.pgsql.yml
+++ b/docker-compose.pgsql.yml
@@ -7,6 +7,7 @@ services:
       POSTGRES_PASSWORD: VippsPw1
       POSTGRES_USER: sa
       POSTGRES_DB: master
+      PGOPTIONS: "-c log_error_verbosity=verbose -c log_statement=all"
     healthcheck:
       test: ["CMD-SHELL", "pg_isready", "-d", "db_prod"]
       interval: 1s
diff --git a/migrations/0003.sqlcode.pgsql b/migrations/0003.sqlcode.pgsql
index a338627..98f9228 100644 --- a/migrations/0003.sqlcode.pgsql +++ b/migrations/0003.sqlcode.pgsql @@ -156,6 +156,52 @@ exception end; $$; +-- similar behaviour as mssql getapplock +-- PostgreSQL advisory locks are session-based by default +create or replace function sqlcode.get_applock( + resource text, + timeout_ms integer default 0 +) +returns integer +language plpgsql +as $$ +declare + resource_key bigint; + acquired boolean; + waited_ms integer := 0; +begin + -- convert string to advisory-lock key + select hashtext(resource) into resource_key; + + -- attempt lock with timeout loop + loop + select pg_try_advisory_lock_shared(resource_key) + into acquired; + + if acquired then + return 1; -- lock acquired (success) + end if; + + if waited_ms >= timeout_ms then + return 0; -- timeout + end if; + + perform pg_sleep(0.01); -- sleep 10 ms + waited_ms := waited_ms + 10; + end loop; + + return null; -- safety fallback (should never hit) +end; +$$; + +create or replace function sqlcode.release_applock(resource text) +returns boolean +language sql +as $$ + select pg_advisory_unlock_shared(hashtext(resource)); +$$; + + -- ensure procedures are owned by the definer role alter procedure sqlcode.createcodeschema(varchar) owner to "sqlcode-definer-role"; diff --git a/sqltest/sqlcode_test.go b/sqltest/sqlcode_test.go index 9de22a3..f343c07 100644 --- a/sqltest/sqlcode_test.go +++ b/sqltest/sqlcode_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/jackc/pgx/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -58,5 +59,25 @@ func Test_EnsureUploaded(t *testing.T) { } fixture.RunMigrationFile("../migrations/0003.sqlcode.pgsql") + + ctx := context.Background() + + _, err := fixture.adminDB.Exec(`grant create on database @database to "sqlcode-definer-role"`, + pgx.NamedArgs{"database": fixture.DBName}) + require.NoError(t, err) + + require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) + patched := SQL.Patch(`[code].Test`) + + 
res, err := fixture.DB.ExecContext(ctx, patched) + require.NoError(t, err) + rowsAffected, err := res.RowsAffected() + require.NoError(t, err) + assert.Equal(t, int64(1), rowsAffected) + + schemas := SQL.ListUploaded(ctx, fixture.DB) + require.Len(t, schemas, 1) + require.Equal(t, 6, schemas[0].Objects) + require.Equal(t, "5420c0269aaf", schemas[0].Suffix()) }) } From b6681a20faa7263b03ee4ab912fbedbb46cd29d7 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 4 Dec 2025 19:19:38 +0100 Subject: [PATCH 07/28] Working EnsureUpload! --- cli/cmd/build.go | 4 +- deployable.go | 72 ++++++++++++++++++++++++++--------- migrations/0003.sqlcode.pgsql | 5 ++- error.go => mssql_error.go | 7 ++-- preprocess.go | 22 +++++++++-- sqlparser/dom.go | 7 +++- sqlparser/parser.go | 13 ++++++- sqlparser/parser_test.go | 18 +++++++++ sqltest/sql.go | 4 ++ sqltest/sqlcode_test.go | 21 ++++------ sqltest/test.pgsql | 8 ++++ 11 files changed, 136 insertions(+), 45 deletions(-) rename error.go => mssql_error.go (92%) create mode 100644 sqltest/test.pgsql diff --git a/cli/cmd/build.go b/cli/cmd/build.go index 1ffdde2..c0d7db3 100644 --- a/cli/cmd/build.go +++ b/cli/cmd/build.go @@ -3,6 +3,8 @@ package cmd import ( "errors" "fmt" + + mssql "github.com/denisenkom/go-mssqldb" "github.com/spf13/cobra" "github.com/vippsas/sqlcode" ) @@ -23,7 +25,7 @@ var ( return err } - preprocessed, err := sqlcode.Preprocess(d.CodeBase, schemasuffix) + preprocessed, err := sqlcode.Preprocess(d.CodeBase, schemasuffix, &mssql.Driver{}) if err != nil { return err } diff --git a/deployable.go b/deployable.go index c1b38fe..ead2709 100644 --- a/deployable.go +++ b/deployable.go @@ -98,7 +98,8 @@ func (d *Deployable) Upload(ctx context.Context, dbc DB) error { } } - preprocessed, err := Preprocess(d.CodeBase, d.SchemaSuffix) + preprocessed, err := Preprocess(d.CodeBase, d.SchemaSuffix, dbc.Driver()) + if err != nil { _ = tx.Rollback() return err @@ -107,15 +108,16 @@ func (d *Deployable) Upload(ctx context.Context, 
dbc DB) error { _, err := tx.ExecContext(ctx, b.Lines) if err != nil { _ = tx.Rollback() - sqlerr, ok := err.(mssql.Error) - if !ok { - return err - } else { - return SQLUserError{ + if sqlerr, ok := err.(mssql.Error); ok { + return MSSQLUserError{ Wrapped: sqlerr, Batch: b, } } + + // TODO(ks) PGSQLUserError + return fmt.Errorf("failed to upload deployable:%s in schema:%s:%w", d.CodeBase, d.SchemaSuffix, err) + } } err = tx.Commit() @@ -327,10 +329,28 @@ func (s *SchemaObject) Suffix() string { // Return a list of sqlcode schemas that have been uploaded to the database. // This includes all current and unused schemas. -func (d *Deployable) ListUploaded(ctx context.Context, dbc DB) []*SchemaObject { +func (d *Deployable) ListUploaded(ctx context.Context, dbc DB) ([]*SchemaObject, error) { objects := []*SchemaObject{} - impersonate(ctx, dbc, "sqlcode-deploy-sandbox-user", func(conn *sql.Conn) error { - rows, err := conn.QueryContext(ctx, ` + driver := dbc.Driver() + var qs string + + var list = func(conn *sql.Conn) error { + rows, err := conn.QueryContext(ctx, qs) + if err != nil { + return err + } + + for rows.Next() { + zero := &SchemaObject{} + rows.Scan(&zero.Name, &zero.Objects, &zero.SchemaId, &zero.CreateDate, &zero.ModifyDate) + objects = append(objects, zero) + } + + return nil + } + + if _, ok := driver.(*mssql.Driver); ok { + qs = ` select s.name , s.schema_id @@ -345,18 +365,32 @@ func (d *Deployable) ListUploaded(ctx context.Context, dbc DB) []*SchemaObject { from sys.objects o where o.schema_id = s.schema_id ) as o - where s.name like 'code@%'`) + where s.name like 'code@%'` + impersonate(ctx, dbc, "sqlcode-deploy-sandbox-user", list) + } + + // TODO(ks) the timestamps for schemas + if _, ok := driver.(*stdlib.Driver); ok { + qs = `select nspname as name + , oid as schema_id + , 0 as objects + , '' as create_date + , '' as modify_date + from pg_namespace + where nspname like 'code@%' + order by nspname` + conn, err := dbc.Conn(ctx) if err != nil { - 
return err + return nil, err } - - for rows.Next() { - zero := &SchemaObject{} - rows.Scan(&zero.Name, &zero.Objects, &zero.SchemaId, &zero.CreateDate, &zero.ModifyDate) - objects = append(objects, zero) + err = list(conn) + if err != nil { + return nil, err } + defer func() { + _ = conn.Close() + }() + } - return nil - }) - return objects + return objects, nil } diff --git a/migrations/0003.sqlcode.pgsql b/migrations/0003.sqlcode.pgsql index 98f9228..1395bc4 100644 --- a/migrations/0003.sqlcode.pgsql +++ b/migrations/0003.sqlcode.pgsql @@ -101,7 +101,7 @@ begin -- create the schema owned by "sqlcode-user-with-no-permissions" execute format( - 'create schema %I authorization %I', + 'create schema if not exists %I authorization %I', schemaname, 'sqlcode-user-with-no-permissions' ); @@ -220,6 +220,9 @@ grant execute on procedure sqlcode.createcodeschema(varchar) grant execute on procedure sqlcode.dropcodeschema(varchar) to "sqlcode-deploy-role"; +grant "sqlcode-user-with-no-permissions" + to "sqlcode-definer-role"; + -- usually deploy role does not need create in the sqlcode management schema -- (the procedures handle creation in separate "code@..." 
schemas) grant usage on schema sqlcode diff --git a/error.go b/mssql_error.go similarity index 92% rename from error.go rename to mssql_error.go index 6131fbf..22d4bde 100644 --- a/error.go +++ b/mssql_error.go @@ -3,17 +3,18 @@ package sqlcode import ( "bytes" "fmt" + "strings" + mssql "github.com/denisenkom/go-mssqldb" "github.com/vippsas/sqlcode/sqlparser" - "strings" ) -type SQLUserError struct { +type MSSQLUserError struct { Wrapped mssql.Error Batch Batch } -func (s SQLUserError) Error() string { +func (s MSSQLUserError) Error() string { var buf bytes.Buffer if _, fmterr := fmt.Fprintf(&buf, "\n"); fmterr != nil { diff --git a/preprocess.go b/preprocess.go index 5a8adab..2c6f647 100644 --- a/preprocess.go +++ b/preprocess.go @@ -2,12 +2,15 @@ package sqlcode import ( "crypto/sha256" + "database/sql/driver" "encoding/hex" "errors" "fmt" - "github.com/vippsas/sqlcode/sqlparser" "regexp" "strings" + + "github.com/jackc/pgx/v5/stdlib" + "github.com/vippsas/sqlcode/sqlparser" ) func SchemaSuffixFromHash(doc sqlparser.Document) string { @@ -138,7 +141,7 @@ func sqlcodeTransformCreate(declares map[string]string, c sqlparser.Create, quot return } -func Preprocess(doc sqlparser.Document, schemasuffix string) (PreprocessedFile, error) { +func Preprocess(doc sqlparser.Document, schemasuffix string, driver driver.Driver) (PreprocessedFile, error) { var result PreprocessedFile if strings.Contains(schemasuffix, "]") { @@ -154,10 +157,21 @@ func Preprocess(doc sqlparser.Document, schemasuffix string) (PreprocessedFile, if len(create.Body) == 0 { continue } - batch, err := sqlcodeTransformCreate(declares, create, "[code@"+schemasuffix+"]") + if create.Driver != driver { + // continue + } + // TODO(ks) this is not reached + target := "[code@" + schemasuffix + "]" + + if _, ok := create.Driver.(*stdlib.Driver); ok { + target = "code@" + schemasuffix + } + + batch, err := sqlcodeTransformCreate(declares, create, target) if err != nil { - return result, err + return result, 
fmt.Errorf("failed to transform create: %w", err) } + fmt.Print(batch) result.Batches = append(result.Batches, batch) } diff --git a/sqlparser/dom.go b/sqlparser/dom.go index cc661f4..14209ee 100644 --- a/sqlparser/dom.go +++ b/sqlparser/dom.go @@ -1,10 +1,12 @@ package sqlparser import ( + "database/sql/driver" "fmt" - "gopkg.in/yaml.v3" "io" "strings" + + "gopkg.in/yaml.v3" ) type Unparsed struct { @@ -64,7 +66,8 @@ type Create struct { QuotedName PosString // proc/func/type name, including [] Body []Unparsed DependsOn []PosString - Docstring []PosString // comment lines before the create statement. Note: this is also part of Body + Docstring []PosString // comment lines before the create statement. Note: this is also part of Body + Driver driver.Driver // the sql driver this document is intended for } func (c Create) DocstringAsString() string { diff --git a/sqlparser/parser.go b/sqlparser/parser.go index 40eebe9..b61f49e 100644 --- a/sqlparser/parser.go +++ b/sqlparser/parser.go @@ -12,6 +12,9 @@ import ( "regexp" "sort" "strings" + + mssql "github.com/denisenkom/go-mssqldb" + "github.com/jackc/pgx/v5/stdlib" ) var templateRoutineName string = "\ndeclare @RoutineName nvarchar(128)\nset @RoutineName = '%s'\n" @@ -276,6 +279,14 @@ func (doc *Document) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { case "create": // should be start of create procedure or create function... c := doc.parseCreate(s, createCountInBatch) + + if strings.HasSuffix(string(s.file), ".sql") { + c.Driver = &mssql.Driver{} + } + if strings.HasSuffix(string(s.file), ".pgsql") { + c.Driver = &stdlib.Driver{} + } + // *prepend* what we saw before getting to the 'create' createCountInBatch++ c.Body = append(nodes, c.Body...) 
@@ -580,7 +591,7 @@ func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string,
 		if strings.HasPrefix(path, ".") || strings.Contains(path, "/.") {
 			return nil
 		}
-		if !strings.HasSuffix(path, ".sql") {
+		if !strings.HasSuffix(path, ".sql") && !strings.HasSuffix(path, ".pgsql") {
 			return nil
 		}
 
diff --git a/sqlparser/parser_test.go b/sqlparser/parser_test.go
index 7bd20b8..3acc3c2 100644
--- a/sqlparser/parser_test.go
+++ b/sqlparser/parser_test.go
@@ -5,10 +5,27 @@ import (
 	"strings"
 	"testing"
 
+	mssql "github.com/denisenkom/go-mssqldb"
+	"github.com/jackc/pgx/v5/stdlib"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
 
+func TestPostgresqlCreate(t *testing.T) {
+	doc := ParseString("test.pgsql", `
+create procedure [code].test()
+language plpgsql
+as $$
+begin
+	perform 1;
+end;
+$$;
+	`)
+
+	require.Len(t, doc.Creates, 1)
+	require.Equal(t, &stdlib.Driver{}, doc.Creates[0].Driver)
+}
+
 func TestParserSmokeTest(t *testing.T) {
 	doc := ParseString("test.sql", `
 /* test is a test
@@ -43,6 +60,7 @@ end;
 	require.Equal(t, 1, len(doc.Creates))
 
 	c := doc.Creates[0]
+	require.Equal(t, &mssql.Driver{}, c.Driver)
 	assert.Equal(t, "[TestFunc]", c.QuotedName.Value)
 	assert.Equal(t, []string{"[HelloFunc]", "[OtherFunc]"}, c.DependsOnStrings())
 
diff --git a/sqltest/sql.go b/sqltest/sql.go
index 15d7995..4ad2754 100644
--- a/sqltest/sql.go
+++ b/sqltest/sql.go
@@ -9,7 +9,11 @@ import (
 //go:embed *.sql
 var sqlfs embed.FS
 
+//go:embed *.pgsql
+var pgsqlfx embed.FS
+
 var SQL = sqlcode.MustInclude(
 	sqlcode.Options{},
 	sqlfs,
+	pgsqlfx,
 )
diff --git a/sqltest/sqlcode_test.go b/sqltest/sqlcode_test.go
index f343c07..db1eb44 100644
--- a/sqltest/sqlcode_test.go
+++ b/sqltest/sqlcode_test.go
@@ -2,9 +2,9 @@ package sqltest
 
 import (
 	"context"
+	"fmt"
 	"testing"
 
-	"github.com/jackc/pgx/v5"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -30,7 +30,8 @@ func Test_RowsAffected(t *testing.T) {
 	require.NoError(t, err)
assert.Equal(t, int64(1), rowsAffected) - schemas := SQL.ListUploaded(ctx, fixture.DB) + schemas, err := SQL.ListUploaded(ctx, fixture.DB) + require.NoError(t, err) require.Len(t, schemas, 1) require.Equal(t, 6, schemas[0].Objects) require.Equal(t, "5420c0269aaf", schemas[0].Suffix()) @@ -62,22 +63,14 @@ func Test_EnsureUploaded(t *testing.T) { ctx := context.Background() - _, err := fixture.adminDB.Exec(`grant create on database @database to "sqlcode-definer-role"`, - pgx.NamedArgs{"database": fixture.DBName}) + _, err := fixture.DB.Exec( + fmt.Sprintf(`grant create on database "%s" to "sqlcode-definer-role"`, fixture.DBName)) require.NoError(t, err) require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) - patched := SQL.Patch(`[code].Test`) - - res, err := fixture.DB.ExecContext(ctx, patched) - require.NoError(t, err) - rowsAffected, err := res.RowsAffected() + schemas, err := SQL.ListUploaded(ctx, fixture.DB) require.NoError(t, err) - assert.Equal(t, int64(1), rowsAffected) + require.Equal(t, "code@e3b0c44298fc", schemas[0].Name) - schemas := SQL.ListUploaded(ctx, fixture.DB) - require.Len(t, schemas, 1) - require.Equal(t, 6, schemas[0].Objects) - require.Equal(t, "5420c0269aaf", schemas[0].Suffix()) }) } diff --git a/sqltest/test.pgsql b/sqltest/test.pgsql new file mode 100644 index 0000000..e16e7dd --- /dev/null +++ b/sqltest/test.pgsql @@ -0,0 +1,8 @@ + +create procedure [code].test() +language plpgsql +as $$ +begin + perform 1; +end; +$$; \ No newline at end of file From da5b11522ca357979b616a4fdc4a444ec9df4bd4 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Fri, 5 Dec 2025 15:04:57 +0100 Subject: [PATCH 08/28] [wip] update parser and scanner --- preprocess_test.go | 385 +++++++++++++++++++++++++++++++++++++-- sqlparser/parser.go | 13 +- sqlparser/parser_test.go | 245 +++++++++++++++++++++++++ sqltest/test.pgsql | 25 ++- 4 files changed, 652 insertions(+), 16 deletions(-) diff --git a/preprocess_test.go b/preprocess_test.go index bf976e8..20998ff 100644 
--- a/preprocess_test.go +++ b/preprocess_test.go @@ -1,24 +1,16 @@ package sqlcode import ( + "strings" "testing" + mssql "github.com/denisenkom/go-mssqldb" + "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vippsas/sqlcode/sqlparser" ) -func TestSchemaSuffixFromHash(t *testing.T) { - t.Run("returns a unique hash", func(t *testing.T) { - doc := sqlparser.Document{ - Declares: []sqlparser.Declare{}, - } - - value := SchemaSuffixFromHash(doc) - require.Equal(t, value, SchemaSuffixFromHash(doc)) - }) -} - func TestLineNumberInInput(t *testing.T) { // Scenario: @@ -63,3 +55,374 @@ func TestLineNumberInInput(t *testing.T) { } assert.Equal(t, expectedInputLineNumbers, inputlines[1:]) } + +func TestSchemaSuffixFromHash(t *testing.T) { + t.Run("returns a unique hash", func(t *testing.T) { + doc := sqlparser.Document{ + Declares: []sqlparser.Declare{}, + } + + value := SchemaSuffixFromHash(doc) + require.Equal(t, value, SchemaSuffixFromHash(doc)) + }) + + t.Run("returns consistent hash", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +declare @EnumFoo int = 1; +go +create procedure [code].Test as begin end +`) + + suffix1 := SchemaSuffixFromHash(doc) + suffix2 := SchemaSuffixFromHash(doc) + + assert.Equal(t, suffix1, suffix2) + assert.Len(t, suffix1, 12) // 6 bytes = 12 hex chars + }) + + t.Run("different content yields different hash", func(t *testing.T) { + doc1 := sqlparser.ParseString("test.sql", ` +declare @EnumFoo int = 1; +go +create procedure [code].Test1 as begin end +`) + doc2 := sqlparser.ParseString("test.sql", ` +declare @EnumFoo int = 2; +go +create procedure [code].Test2 as begin end +`) + + suffix1 := SchemaSuffixFromHash(doc1) + suffix2 := SchemaSuffixFromHash(doc2) + + assert.NotEqual(t, suffix1, suffix2) + }) + + t.Run("empty document has hash", func(t *testing.T) { + doc := sqlparser.Document{} + suffix := SchemaSuffixFromHash(doc) + assert.Len(t, suffix, 12) + }) 
+} + +func TestSchemaName(t *testing.T) { + assert.Equal(t, "code@abc123", SchemaName("abc123")) + assert.Equal(t, "code@", SchemaName("")) +} + +func TestBatchLineNumberInInput(t *testing.T) { + t.Run("no corrections", func(t *testing.T) { + b := Batch{ + StartPos: sqlparser.Pos{Line: 10, Col: 1}, + Lines: "line1\nline2\nline3", + } + + assert.Equal(t, 10, b.LineNumberInInput(1)) + assert.Equal(t, 11, b.LineNumberInInput(2)) + assert.Equal(t, 12, b.LineNumberInInput(3)) + }) + + t.Run("with corrections", func(t *testing.T) { + b := Batch{ + StartPos: sqlparser.Pos{Line: 10, Col: 1}, + Lines: "line1\nline2\nextra1\nextra2\nline3", + lineNumberCorrections: []lineNumberCorrection{ + {inputLineNumber: 2, extraLinesInOutput: 2}, // line 2 became 3 lines + }, + } + + assert.Equal(t, 10, b.LineNumberInInput(1)) // line 1 -> input line 10 + assert.Equal(t, 11, b.LineNumberInInput(2)) // line 2 -> input line 11 + assert.Equal(t, 11, b.LineNumberInInput(3)) // extra line -> still input line 11 + assert.Equal(t, 11, b.LineNumberInInput(4)) // extra line -> still input line 11 + assert.Equal(t, 12, b.LineNumberInInput(5)) // line 3 -> input line 12 + }) +} + +func TestBatchRelativeLineNumberInInput(t *testing.T) { + t.Run("simple case with no corrections", func(t *testing.T) { + b := Batch{ + lineNumberCorrections: []lineNumberCorrection{}, + } + + assert.Equal(t, 1, b.RelativeLineNumberInInput(1)) + assert.Equal(t, 5, b.RelativeLineNumberInInput(5)) + }) + + t.Run("single correction", func(t *testing.T) { + b := Batch{ + lineNumberCorrections: []lineNumberCorrection{ + {inputLineNumber: 3, extraLinesInOutput: 2}, + }, + } + + assert.Equal(t, 1, b.RelativeLineNumberInInput(1)) + assert.Equal(t, 2, b.RelativeLineNumberInInput(2)) + assert.Equal(t, 3, b.RelativeLineNumberInInput(3)) + assert.Equal(t, 3, b.RelativeLineNumberInInput(4)) // extra line + assert.Equal(t, 3, b.RelativeLineNumberInInput(5)) // extra line + assert.Equal(t, 4, b.RelativeLineNumberInInput(6)) + }) + + 
t.Run("multiple corrections", func(t *testing.T) { + b := Batch{ + lineNumberCorrections: []lineNumberCorrection{ + {inputLineNumber: 2, extraLinesInOutput: 1}, + {inputLineNumber: 5, extraLinesInOutput: 3}, + }, + } + + assert.Equal(t, 1, b.RelativeLineNumberInInput(1)) + assert.Equal(t, 2, b.RelativeLineNumberInInput(2)) + assert.Equal(t, 2, b.RelativeLineNumberInInput(3)) // extra from line 2 + assert.Equal(t, 3, b.RelativeLineNumberInInput(4)) + assert.Equal(t, 4, b.RelativeLineNumberInInput(5)) + assert.Equal(t, 5, b.RelativeLineNumberInInput(6)) + assert.Equal(t, 5, b.RelativeLineNumberInInput(7)) // extra from line 5 + assert.Equal(t, 5, b.RelativeLineNumberInInput(8)) // extra from line 5 + assert.Equal(t, 5, b.RelativeLineNumberInInput(9)) // extra from line 5 + assert.Equal(t, 6, b.RelativeLineNumberInInput(10)) + }) +} + +func TestPreprocess(t *testing.T) { + t.Run("basic procedure with schema replacement", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +create procedure [code].Test as +begin + select 1 +end +`) + doc.Creates[0].Driver = &mssql.Driver{} + + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + assert.Contains(t, result.Batches[0].Lines, "[code@abc123]") + assert.NotContains(t, result.Batches[0].Lines, "[code]") + }) + + t.Run("postgres uses unquoted schema name", func(t *testing.T) { + doc := sqlparser.ParseString("test.pgsql", ` +create procedure [code].test() as $$ +begin + perform 1; +end; +$$ language plpgsql; +`) + doc.Creates[0].Driver = &stdlib.Driver{} + + result, err := Preprocess(doc, "abc123", &stdlib.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + assert.Contains(t, result.Batches[0].Lines, "code@abc123") + assert.NotContains(t, result.Batches[0].Lines, "[code@abc123]") + }) + + t.Run("replaces enum constants", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +declare @EnumStatus int = 42; +go 
+create procedure [code].Test as +begin + select @EnumStatus +end +`) + doc.Creates[0].Driver = &mssql.Driver{} + + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + batch := result.Batches[0].Lines + assert.Contains(t, batch, "42/*=@EnumStatus*/") + assert.NotContains(t, batch, "@EnumStatus\n") // shouldn't have bare reference + }) + + t.Run("handles multiline string constants", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +declare @EnumMulti nvarchar(max) = N'line1 +line2 +line3'; +go +create procedure [code].Test as +begin + select @EnumMulti +end +`) + doc.Creates[0].Driver = &mssql.Driver{} + + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + batch := result.Batches[0] + assert.Contains(t, batch.Lines, "N'line1\nline2\nline3'/*=@EnumMulti*/") + // Should have line number corrections for the 2 extra lines + assert.Len(t, batch.lineNumberCorrections, 1) + assert.Equal(t, 2, batch.lineNumberCorrections[0].extraLinesInOutput) + }) + + t.Run("error on undeclared constant", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +create procedure [code].Test as +begin + select @EnumUndeclared +end +`) + doc.Creates[0].Driver = &mssql.Driver{} + + _, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.Error(t, err) + + var preprocErr PreprocessorError + require.ErrorAs(t, err, &preprocErr) + assert.Contains(t, preprocErr.Message, "@EnumUndeclared") + assert.Contains(t, preprocErr.Message, "not declared") + }) + + t.Run("error on schema suffix with bracket", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +create procedure [code].Test as begin end +`) + doc.Creates[0].Driver = &mssql.Driver{} + + _, err := Preprocess(doc, "abc]123", &mssql.Driver{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "schemasuffix cannot contain") + }) + + t.Run("skips creates 
with empty body", func(t *testing.T) { + doc := sqlparser.Document{ + Creates: []sqlparser.Create{ + {Body: []sqlparser.Unparsed{}}, + }, + } + + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + assert.Empty(t, result.Batches) + }) + + t.Run("handles multiple creates", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +create procedure [code].Proc1 as begin select 1 end +go +create procedure [code].Proc2 as begin select 2 end +`) + doc.Creates[0].Driver = &mssql.Driver{} + doc.Creates[1].Driver = &mssql.Driver{} + + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + assert.Len(t, result.Batches, 2) + + assert.Contains(t, result.Batches[0].Lines, "Proc1") + assert.Contains(t, result.Batches[1].Lines, "Proc2") + }) + + t.Run("handles multiple constants in same procedure", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +declare @EnumA int = 1, @EnumB int = 2; +go +create procedure [code].Test as +begin + select @EnumA, @EnumB +end +`) + doc.Creates[0].Driver = &mssql.Driver{} + + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + batch := result.Batches[0].Lines + assert.Contains(t, batch, "1/*=@EnumA*/") + assert.Contains(t, batch, "2/*=@EnumB*/") + }) + + t.Run("preserves comments and formatting", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", ` +-- This is a test procedure +create procedure [code].Test as +begin + /* multi + line + comment */ + select 1 +end +`) + doc.Creates[0].Driver = &mssql.Driver{} + + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + batch := result.Batches[0].Lines + assert.Contains(t, batch, "-- This is a test procedure") + assert.Contains(t, batch, "/* multi") + }) + + t.Run("handles const and global prefixes", func(t *testing.T) { + doc := sqlparser.ParseString("test.sql", 
` +declare @ConstValue int = 100, @GlobalSetting nvarchar(50) = N'test'; +go +create procedure [code].Test as +begin + select @ConstValue, @GlobalSetting +end +`) + doc.Creates[0].Driver = &mssql.Driver{} + + result, err := Preprocess(doc, "abc123", &mssql.Driver{}) + require.NoError(t, err) + require.Len(t, result.Batches, 1) + + batch := result.Batches[0].Lines + assert.Contains(t, batch, "100/*=@ConstValue*/") + assert.Contains(t, batch, "N'test'/*=@GlobalSetting*/") + }) +} + +func TestPreprocessString(t *testing.T) { + t.Run("replaces code schema", func(t *testing.T) { + result := preprocessString("abc123", "select * from [code].Table") + assert.Equal(t, "select * from [code@abc123].Table", result) + }) + + t.Run("case insensitive replacement", func(t *testing.T) { + result := preprocessString("abc123", "select * from [CODE].Table and [Code].Other") + assert.Contains(t, result, "[code@abc123].Table") + assert.Contains(t, result, "[code@abc123].Other") + }) + + t.Run("multiple occurrences", func(t *testing.T) { + sql := ` + select * from [code].A + join [code].B on A.id = B.id + where exists (select 1 from [code].C) + ` + result := preprocessString("abc123", sql) + assert.Equal(t, 3, strings.Count(result, "[code@abc123]")) + assert.NotContains(t, result, "[code].") + }) + + t.Run("no replacement needed", func(t *testing.T) { + sql := "select * from dbo.Table" + result := preprocessString("abc123", sql) + assert.Equal(t, sql, result) + }) +} + +func TestPreprocessorError(t *testing.T) { + t.Run("formats error message", func(t *testing.T) { + err := PreprocessorError{ + Pos: sqlparser.Pos{File: "test.sql", Line: 10, Col: 5}, + Message: "something went wrong", + } + + assert.Equal(t, "test.sql:10:5: something went wrong", err.Error()) + }) +} diff --git a/sqlparser/parser.go b/sqlparser/parser.go index b61f49e..b88e4ea 100644 --- a/sqlparser/parser.go +++ b/sqlparser/parser.go @@ -9,7 +9,9 @@ import ( "errors" "fmt" "io/fs" + "path/filepath" "regexp" + "slices" 
"sort" "strings" @@ -19,6 +21,8 @@ import ( var templateRoutineName string = "\ndeclare @RoutineName nvarchar(128)\nset @RoutineName = '%s'\n" +var supportedSqlExtensions []string = []string{".sql", ".pgsql"} + func CopyToken(s *Scanner, target *[]Unparsed) { *target = append(*target, CreateUnparsed(s)) } @@ -564,9 +568,8 @@ func ParseString(filename FileRef, input string) (result Document) { return } -// ParseFileystems iterates through a list of filesystems and parses all files -// matching `*.sql`, determines which one are sqlcode files from the contents, -// and returns the combination of all of them. +// ParseFileystems iterates through a list of filesystems and parses all supported +// SQL files and returns the combination of all of them. // // err will only return errors related to filesystems/reading. Errors // related to parsing/sorting will be in result.Errors. @@ -591,7 +594,9 @@ func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string, if strings.HasPrefix(path, ".") || strings.Contains(path, "/.") { return nil } - if !strings.HasSuffix(path, ".sql") || !strings.HasSuffix(path, ".pgsql") { + + extension := filepath.Ext(path) + if !slices.Contains(supportedSqlExtensions, extension) { return nil } diff --git a/sqlparser/parser_test.go b/sqlparser/parser_test.go index 3acc3c2..fc75460 100644 --- a/sqlparser/parser_test.go +++ b/sqlparser/parser_test.go @@ -2,8 +2,10 @@ package sqlparser import ( "fmt" + "io/fs" "strings" "testing" + "testing/fstest" mssql "github.com/denisenkom/go-mssqldb" "github.com/jackc/pgx/v5/stdlib" @@ -410,3 +412,246 @@ create procedure [code].Foo as begin end err.Error()) } + +func TestParseFilesystems(t *testing.T) { + t.Run("basic parsing of sql files", func(t *testing.T) { + fsys := fstest.MapFS{ + "test1.sql": &fstest.MapFile{ + Data: []byte(` +declare @EnumFoo int = 1; +go +create procedure [code].Proc1 as begin end +`), + }, + "test2.sql": &fstest.MapFile{ + Data: []byte(` +create function [code].Func1() 
returns int as begin return 1 end +`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 2) + assert.Len(t, doc.Creates, 2) + assert.Len(t, doc.Declares, 1) + }) + + t.Run("filters by include tags", func(t *testing.T) { + fsys := fstest.MapFS{ + "included.sql": &fstest.MapFile{ + Data: []byte(`--sqlcode:include-if foo,bar +create procedure [code].Included as begin end +`), + }, + "excluded.sql": &fstest.MapFile{ + Data: []byte(`--sqlcode:include-if baz +create procedure [code].Excluded as begin end +`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, []string{"foo", "bar"}) + require.NoError(t, err) + assert.Len(t, filenames, 1) + assert.Contains(t, filenames[0], "included.sql") + assert.Len(t, doc.Creates, 1) + assert.Equal(t, "[Included]", doc.Creates[0].QuotedName.Value) + }) + + t.Run("detects duplicate files with same hash", func(t *testing.T) { + contents := []byte(`create procedure [code].Test as begin end`) + + fs1 := fstest.MapFS{ + "test.sql": &fstest.MapFile{Data: contents}, + } + fs2 := fstest.MapFS{ + "test.sql": &fstest.MapFile{Data: contents}, + } + + _, _, err := ParseFilesystems([]fs.FS{fs1, fs2}, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "exact same contents") + }) + + t.Run("skips non-sqlcode files", func(t *testing.T) { + fsys := fstest.MapFS{ + "regular.sql": &fstest.MapFile{ + Data: []byte(`select * from table1`), + }, + "sqlcode.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Test as begin end`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 1) + assert.Contains(t, filenames[0], "sqlcode.sql") + assert.Len(t, doc.Creates, 1) + }) + + t.Run("skips hidden directories", func(t *testing.T) { + fsys := fstest.MapFS{ + "visible.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Visible as begin end`), + }, + ".hidden/test.sql": 
&fstest.MapFile{ + Data: []byte(`create procedure [code].Hidden as begin end`), + }, + "dir/.git/test.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Git as begin end`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 1) + assert.Contains(t, filenames[0], "visible.sql") + assert.Len(t, doc.Creates, 1) + }) + + t.Run("handles dependencies and topological sort", func(t *testing.T) { + fsys := fstest.MapFS{ + "proc1.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Proc1 as begin exec [code].Proc2 end`), + }, + "proc2.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Proc2 as begin select 1 end`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 2) + assert.Len(t, doc.Creates, 2) + // Proc2 should come before Proc1 due to dependency + assert.Equal(t, "[Proc2]", doc.Creates[0].QuotedName.Value) + assert.Equal(t, "[Proc1]", doc.Creates[1].QuotedName.Value) + }) + + t.Run("reports topological sort errors", func(t *testing.T) { + fsys := fstest.MapFS{ + "circular1.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].A as begin exec [code].B end`), + }, + "circular2.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].B as begin exec [code].A end`), + }, + } + + _, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) // filesystem error should be nil + assert.NotEmpty(t, doc.Errors) // but parsing errors should exist + assert.Contains(t, doc.Errors[0].Message, "Detected a dependency cycle") + }) + + t.Run("handles multiple filesystems", func(t *testing.T) { + fs1 := fstest.MapFS{ + "test1.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Proc1 as begin end`), + }, + } + fs2 := fstest.MapFS{ + "test2.sql": &fstest.MapFile{ + Data: []byte(`create procedure [code].Proc2 as begin end`), + }, + } + + filenames, doc, err := 
ParseFilesystems([]fs.FS{fs1, fs2}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 2) + assert.Contains(t, filenames[0], "fs[0]:") + assert.Contains(t, filenames[1], "fs[1]:") + assert.Len(t, doc.Creates, 2) + }) + + t.Run("detects sqlcode files by pragma header", func(t *testing.T) { + fsys := fstest.MapFS{ + "test.sql": &fstest.MapFile{ + Data: []byte(`--sqlcode:include-if foo +create procedure NotInCodeSchema.Test as begin end`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, []string{"foo"}) + require.NoError(t, err) + assert.Len(t, filenames, 1) + // Should still parse even though it will have errors (not in [code] schema) + assert.NotEmpty(t, doc.Errors) + }) + + t.Run("handles pgsql extension", func(t *testing.T) { + fsys := fstest.MapFS{ + "test.pgsql": &fstest.MapFile{ + Data: []byte(` +create procedure [code].test() +language plpgsql +as $$ +begin + perform 1; +end; +$$; +`), + }, + } + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Len(t, filenames, 1) + assert.Len(t, doc.Creates, 1) + assert.Equal(t, &stdlib.Driver{}, doc.Creates[0].Driver) + }) + + t.Run("empty filesystem returns empty results", func(t *testing.T) { + fsys := fstest.MapFS{} + + filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) + require.NoError(t, err) + assert.Empty(t, filenames) + assert.Empty(t, doc.Creates) + assert.Empty(t, doc.Declares) + }) +} + +func TestMatchesIncludeTags(t *testing.T) { + t.Run("empty requirements matches anything", func(t *testing.T) { + assert.True(t, matchesIncludeTags([]string{}, []string{})) + assert.True(t, matchesIncludeTags([]string{}, []string{"foo"})) + }) + + t.Run("all requirements must be met", func(t *testing.T) { + assert.True(t, matchesIncludeTags([]string{"foo", "bar"}, []string{"foo", "bar", "baz"})) + assert.False(t, matchesIncludeTags([]string{"foo", "bar"}, []string{"foo"})) + assert.False(t, matchesIncludeTags([]string{"foo", "bar"}, 
[]string{"bar"})) + }) + + t.Run("exact match", func(t *testing.T) { + assert.True(t, matchesIncludeTags([]string{"foo"}, []string{"foo"})) + assert.False(t, matchesIncludeTags([]string{"foo"}, []string{"bar"})) + }) +} + +func TestIsSqlcodeConstVariable(t *testing.T) { + testCases := []struct { + name string + varname string + expected bool + }{ + {"@Enum prefix", "@EnumFoo", true}, + {"@ENUM_ prefix", "@ENUM_FOO", true}, + {"@enum_ prefix", "@enum_foo", true}, + {"@Const prefix", "@ConstFoo", true}, + {"@CONST_ prefix", "@CONST_FOO", true}, + {"@const_ prefix", "@const_foo", true}, + {"regular variable", "@MyVariable", false}, + {"@Global prefix", "@GlobalVar", false}, + {"no @ prefix", "EnumFoo", false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, IsSqlcodeConstVariable(tc.varname)) + }) + } +} diff --git a/sqltest/test.pgsql b/sqltest/test.pgsql index e16e7dd..49992da 100644 --- a/sqltest/test.pgsql +++ b/sqltest/test.pgsql @@ -1,5 +1,28 @@ +-- consts can be -create procedure [code].test() +-- sqlcode +-- we define schemas per deployment +-- uploading all stored functions/procedures/types/consts to a schema +-- pods are restarted/deployed per deployment + +-- aaa bbb + 3 1 + +(iof increase in errors, stop deployment) +-- aaa bbb + 3 0 +-- aaa bbb + 1 2 +-- aaa bbb + 0 3 +-- + +-- ++ both mssql and pgsql have the same architecture with schemas and stored functions/procedures + +-- Q: constants? +-- we have the same constants defined in both mssql and pggsql + +create procedure [code].test() -- expands to code@aaa.test language plpgsql as $$ begin From 18309f07fa844dfca0ed2edf482c8bc78bf3d400 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Tue, 9 Dec 2025 21:27:56 +0100 Subject: [PATCH 09/28] Fixed issue with Preprocess. Passing pgsql tests. 
--- deployable.go | 21 +++++- deployable_test.go | 6 ++ dockerfile.test | 1 + ...{0003.sqlcode.pgsql => 0001.sqlcode.pgsql} | 0 preprocess.go | 26 ++++--- preprocess_test.go | 13 ++-- sqltest/fixture.go | 31 ++++++++ sqltest/sqlcode_test.go | 75 ++++++++++--------- sqltest/test.pgsql | 26 +------ 9 files changed, 118 insertions(+), 81 deletions(-) rename migrations/{0003.sqlcode.pgsql => 0001.sqlcode.pgsql} (100%) diff --git a/deployable.go b/deployable.go index ead2709..ec784d7 100644 --- a/deployable.go +++ b/deployable.go @@ -99,7 +99,6 @@ func (d *Deployable) Upload(ctx context.Context, dbc DB) error { } preprocessed, err := Preprocess(d.CodeBase, d.SchemaSuffix, dbc.Driver()) - if err != nil { _ = tx.Rollback() return err @@ -244,11 +243,28 @@ func (d Deployable) DropAndUpload(ctx context.Context, dbc DB) error { } // Patch will preprocess the sql passed in so that it will call SQL code -// deployed by the receiver Deployable +// deployed by the receiver Deployable for SQL Server. +// NOTE: This will be deprecated and eventually replaced with CodePatch. func (d Deployable) Patch(sql string) string { return preprocessString(d.SchemaSuffix, sql) } +// CodePatch will preprocess the sql passed in to call +// the correct SQL code deployed to the provided database. +// Q: Nameing? DBPatch, PatchV2, ??? +func (d Deployable) CodePatch(dbc *sql.DB, sql string) string { + driver := dbc.Driver() + if _, ok := driver.(*mssql.Driver); ok { + return codeSchemaRegexp.ReplaceAllString(sql, fmt.Sprintf(`[code@%s]`, d.SchemaSuffix)) + } + + if _, ok := driver.(*stdlib.Driver); ok { + return codeSchemaRegexp.ReplaceAllString(sql, fmt.Sprintf(`"code@%s"`, d.SchemaSuffix)) + } + + panic("unhandled sql driver") +} + func (d *Deployable) markAsUploaded(dbc DB) { d.uploaded[dbc] = struct{}{} } @@ -260,7 +276,6 @@ func (d *Deployable) IsUploadedFromCache(dbc DB) bool { // TODO: StringConst. 
This requires parsing a SQL literal, a bit too complex // to code up just-in-case - func (d Deployable) IntConst(s string) (int, error) { for _, declare := range d.CodeBase.Declares { if declare.VariableName == s { diff --git a/deployable_test.go b/deployable_test.go index 1e9dac5..7b87b57 100644 --- a/deployable_test.go +++ b/deployable_test.go @@ -21,5 +21,11 @@ declare @EnumInt int = 1, @EnumString varchar(max) = 'hello'; n, err := d.IntConst("@EnumInt") require.NoError(t, err) assert.Equal(t, 1, n) +} +func TestPatch(t *testing.T) { + t.Run("mssql schemasuffix", func(t *testing.T) { + d := Deployable{} + require.Equal(t, "[code@].Foo", d.Patch("[code].Foo")) + }) } diff --git a/dockerfile.test b/dockerfile.test index f4a199f..d8cdf13 100644 --- a/dockerfile.test +++ b/dockerfile.test @@ -3,4 +3,5 @@ WORKDIR /sqlcode ENV GODEBUG="x509negativeserial=1" COPY . . RUN go mod tidy +# Skip the example folder because it has examples of what-not-to-do CMD ["go", "test", "-v", "$(go list ./... 
| grep -v './example')"] \ No newline at end of file diff --git a/migrations/0003.sqlcode.pgsql b/migrations/0001.sqlcode.pgsql similarity index 100% rename from migrations/0003.sqlcode.pgsql rename to migrations/0001.sqlcode.pgsql diff --git a/preprocess.go b/preprocess.go index 2c6f647..c4a776b 100644 --- a/preprocess.go +++ b/preprocess.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "errors" "fmt" + "reflect" "regexp" "strings" @@ -131,7 +132,6 @@ func sqlcodeTransformCreate(declares map[string]string, c sqlparser.Create, quot result.lineNumberCorrections = append(result.lineNumberCorrections, lineNumberCorrection{relativeLine, newlineCount}) } } - if _, err = w.WriteString(token); err != nil { return } @@ -153,25 +153,29 @@ func Preprocess(doc sqlparser.Document, schemasuffix string, driver driver.Drive declares[dec.VariableName] = dec.Literal.RawValue } + // The current sql driver that we are preparring for + currentDriver := reflect.TypeOf(driver) + + // the default target for mssql + target := fmt.Sprintf(`[code@%s]`, schemasuffix) + + // pgsql target + if _, ok := driver.(*stdlib.Driver); ok { + target = fmt.Sprintf(`"code@%s"`, schemasuffix) + } + for _, create := range doc.Creates { if len(create.Body) == 0 { continue } - if create.Driver != driver { - // continue - } - // TODO(ks) this is not reached - target := "[code@" + schemasuffix + "]" - - if _, ok := create.Driver.(*stdlib.Driver); ok { - target = "code@" + schemasuffix + if !currentDriver.AssignableTo(reflect.TypeOf(create.Driver)) { + // this batch is for a different sql driver + continue } - batch, err := sqlcodeTransformCreate(declares, create, target) if err != nil { return result, fmt.Errorf("failed to transform create: %w", err) } - fmt.Print(batch) result.Batches = append(result.Batches, batch) } diff --git a/preprocess_test.go b/preprocess_test.go index 20998ff..940b8f6 100644 --- a/preprocess_test.go +++ b/preprocess_test.go @@ -199,8 +199,8 @@ end require.NoError(t, err) require.Len(t, 
result.Batches, 1) - assert.Contains(t, result.Batches[0].Lines, "[code@abc123]") - assert.NotContains(t, result.Batches[0].Lines, "[code]") + assert.Contains(t, result.Batches[0].Lines, "[code@abc123].") + assert.NotContains(t, result.Batches[0].Lines, "[code].") }) t.Run("postgres uses unquoted schema name", func(t *testing.T) { @@ -217,8 +217,8 @@ $$ language plpgsql; require.NoError(t, err) require.Len(t, result.Batches, 1) - assert.Contains(t, result.Batches[0].Lines, "code@abc123") - assert.NotContains(t, result.Batches[0].Lines, "[code@abc123]") + assert.Contains(t, result.Batches[0].Lines, "code@abc123.") + assert.NotContains(t, result.Batches[0].Lines, "[code@abc123].") }) t.Run("replaces enum constants", func(t *testing.T) { @@ -367,7 +367,8 @@ end t.Run("handles const and global prefixes", func(t *testing.T) { doc := sqlparser.ParseString("test.sql", ` -declare @ConstValue int = 100, @GlobalSetting nvarchar(50) = N'test'; +declare @ConstValue int = 100; +declare @GlobalSetting nvarchar(50) = N'test'; go create procedure [code].Test as begin @@ -382,7 +383,7 @@ end batch := result.Batches[0].Lines assert.Contains(t, batch, "100/*=@ConstValue*/") - assert.Contains(t, batch, "N'test'/*=@GlobalSetting*/") + assert.NotContains(t, batch, "N'test'/*=@GlobalSetting*/") }) } diff --git a/sqltest/fixture.go b/sqltest/fixture.go index 82f395d..e05c418 100644 --- a/sqltest/fixture.go +++ b/sqltest/fixture.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "strings" + "testing" "time" mssql "github.com/denisenkom/go-mssqldb" @@ -140,6 +141,24 @@ func NewFixture() *Fixture { if err != nil { panic(err) } + + var user string + err = fixture.DB.QueryRow(`select current_user`).Scan(&user) + if err != nil { + panic(err) + } + _, err = fixture.DB.Exec(fmt.Sprintf(`GRANT ALL ON DATABASE "%s" TO sa;`, fixture.DBName)) + if err != nil { + panic(err) + } + _, err = fixture.DB.Exec(fmt.Sprintf(`GRANT ALL PRIVILEGES ON SCHEMA public TO %s;`, user)) + if err != nil { + panic(err) + } + _, err 
= fixture.DB.Exec(fmt.Sprintf(`ALTER DATABASE "%s" OWNER TO sa;`, fixture.DBName)) + if err != nil { + panic(err) + } } return &fixture @@ -177,3 +196,15 @@ func (f *Fixture) RunMigrationFile(filename string) { } } } + +func (f *Fixture) RunIfPostgres(t *testing.T, name string, fn func(t *testing.T)) { + if f.IsPostgresql() { + t.Run(name, fn) + } +} + +func (f *Fixture) RunIfMssql(t *testing.T, name string, fn func(t *testing.T)) { + if f.IsSqlServer() { + t.Run(name, fn) + } +} diff --git a/sqltest/sqlcode_test.go b/sqltest/sqlcode_test.go index db1eb44..910064b 100644 --- a/sqltest/sqlcode_test.go +++ b/sqltest/sqlcode_test.go @@ -9,68 +9,71 @@ import ( "github.com/stretchr/testify/require" ) -func Test_RowsAffected(t *testing.T) { +func Test_Patch(t *testing.T) { fixture := NewFixture() + ctx := context.Background() defer fixture.Teardown() - t.Run("mssql", func(t *testing.T) { - if !fixture.IsSqlServer() { - t.Skip() - } + if fixture.IsSqlServer() { fixture.RunMigrationFile("../migrations/0001.sqlcode.sql") + } - ctx := context.Background() + if fixture.IsPostgresql() { + fixture.RunMigrationFile("../migrations/0001.sqlcode.pgsql") + _, err := fixture.DB.Exec( + fmt.Sprintf(`grant create on database "%s" to "sqlcode-definer-role"`, fixture.DBName)) + require.NoError(t, err) + } - require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) - patched := SQL.Patch(`[code].Test`) + require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) + fixture.RunIfMssql(t, "mssql", func(t *testing.T) { + patched := SQL.CodePatch(fixture.DB, `[code].Test`) res, err := fixture.DB.ExecContext(ctx, patched) require.NoError(t, err) + rowsAffected, err := res.RowsAffected() require.NoError(t, err) assert.Equal(t, int64(1), rowsAffected) + }) - schemas, err := SQL.ListUploaded(ctx, fixture.DB) + fixture.RunIfPostgres(t, "pgsql", func(t *testing.T) { + patched := SQL.CodePatch(fixture.DB, `call [code].Test()`) + res, err := fixture.DB.ExecContext(ctx, patched) require.NoError(t, err) - 
require.Len(t, schemas, 1) - require.Equal(t, 6, schemas[0].Objects) - require.Equal(t, "5420c0269aaf", schemas[0].Suffix()) + // postgresql perform does not result with affected rows + rowsAffected, err := res.RowsAffected() + require.NoError(t, err) + assert.Equal(t, int64(0), rowsAffected) }) } func Test_EnsureUploaded(t *testing.T) { - fixture := NewFixture() - defer fixture.Teardown() - - t.Run("mssql", func(t *testing.T) { - if !fixture.IsSqlServer() { - t.Skip() - } - fixture.RunMigrationFile("../migrations/0001.sqlcode.sql") + f := NewFixture() + defer f.Teardown() + ctx := context.Background() + + f.RunIfMssql(t, "mssql", func(t *testing.T) { + f.RunMigrationFile("../migrations/0001.sqlcode.sql") + require.NoError(t, SQL.EnsureUploaded(ctx, f.DB)) + schemas, err := SQL.ListUploaded(ctx, f.DB) + require.NoError(t, err) + require.Len(t, schemas, 1) - ctx := context.Background() - require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) }) - t.Run("pgsql", func(t *testing.T) { - if !fixture.IsPostgresql() { - t.Skip() - } - - fixture.RunMigrationFile("../migrations/0003.sqlcode.pgsql") + f.RunIfPostgres(t, "pgsql", func(t *testing.T) { + f.RunMigrationFile("../migrations/0001.sqlcode.pgsql") - ctx := context.Background() - - _, err := fixture.DB.Exec( - fmt.Sprintf(`grant create on database "%s" to "sqlcode-definer-role"`, fixture.DBName)) + _, err := f.DB.Exec( + fmt.Sprintf(`grant create on database "%s" to "sqlcode-definer-role"`, f.DBName)) require.NoError(t, err) - require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) - schemas, err := SQL.ListUploaded(ctx, fixture.DB) + require.NoError(t, SQL.EnsureUploaded(ctx, f.DB)) + schemas, err := SQL.ListUploaded(ctx, f.DB) require.NoError(t, err) - require.Equal(t, "code@e3b0c44298fc", schemas[0].Name) - + require.Len(t, schemas, 1) }) } diff --git a/sqltest/test.pgsql b/sqltest/test.pgsql index 49992da..bff51a5 100644 --- a/sqltest/test.pgsql +++ b/sqltest/test.pgsql @@ -1,28 +1,4 @@ --- consts can be - --- 
sqlcode --- we define schemas per deployment --- uploading all stored functions/procedures/types/consts to a schema --- pods are restarted/deployed per deployment - --- aaa bbb - 3 1 - -(iof increase in errors, stop deployment) --- aaa bbb - 3 0 --- aaa bbb - 1 2 --- aaa bbb - 0 3 --- - --- ++ both mssql and pgsql have the same architecture with schemas and stored functions/procedures - --- Q: constants? --- we have the same constants defined in both mssql and pggsql - -create procedure [code].test() -- expands to code@aaa.test +create procedure [code].test() language plpgsql as $$ begin From b6518144e45b182fe2b15e6632980cb093bb3e31 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Tue, 9 Dec 2025 21:28:06 +0100 Subject: [PATCH 10/28] Updated GO workflow to test both drivers. --- .github/workflows/go.yml | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 926f7a6..e5272b4 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -1,7 +1,7 @@ # This workflow will build a golang project # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go -name: go-querysql-test +name: sqlcode on: pull_request: @@ -11,19 +11,16 @@ jobs: build: runs-on: ubuntu-latest - env: - SQLSERVER_DSN: "sqlserver://127.0.0.1:1433?database=master&user id=sa&password=VippsPw1" + strategy: + matrix: + drivers: ['mssql','pgsql'] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Set up Go uses: actions/setup-go@v5 with: - go-version: '1.23' - - - name: Start db - run: docker compose -f docker-compose.test.yml up -d + go-version: '1.25' - name: Test - # Skip the example folder because it has examples of what-not-to-do - run: go test -v $(go list ./... 
| grep -v './example') + run: docker compose -f docker-compose.${{ matrix.driver }}.yml run test \ No newline at end of file From 160df176f73262f2c51e9adb458c832dde1288e6 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Tue, 9 Dec 2025 21:29:34 +0100 Subject: [PATCH 11/28] Fixed typo in GH workflow. --- .github/workflows/go.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index e5272b4..64b63c3 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - drivers: ['mssql','pgsql'] + driver: ['mssql','pgsql'] steps: - uses: actions/checkout@v5 From 11659eda973a966211d7d7169d29d0dcb66fb985 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Tue, 9 Dec 2025 21:41:36 +0100 Subject: [PATCH 12/28] Updated Dockerfile --- dockerfile.test | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dockerfile.test b/dockerfile.test index d8cdf13..286268c 100644 --- a/dockerfile.test +++ b/dockerfile.test @@ -1,7 +1,7 @@ -FROM golang:1.25.1 AS builder +FROM golang:1.25 AS builder WORKDIR /sqlcode ENV GODEBUG="x509negativeserial=1" COPY . . RUN go mod tidy # Skip the example folder because it has examples of what-not-to-do -CMD ["go", "test", "-v", "$(go list ./... 
| grep -v './example')"]
\ No newline at end of file
+CMD ["go", "test", "-v", "./..."]
\ No newline at end of file

From 37fd588cdcdf2370d318cb8ccf684b8ddbc63390 Mon Sep 17 00:00:00 2001
From: Kyle Simukka
Date: Tue, 9 Dec 2025 21:44:21 +0100
Subject: [PATCH 13/28] Use build tags to exclude examples from build & test

---
 example/basic/example.go | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/example/basic/example.go b/example/basic/example.go
index abe1194..9406915 100644
--- a/example/basic/example.go
+++ b/example/basic/example.go
@@ -1,3 +1,6 @@
+//go:build examples
+// +build examples
+
 package example
 
 import (

From ca444bc09a2ed95e81f8c5998ed2ace57d6ff11e Mon Sep 17 00:00:00 2001
From: Kyle Simukka
Date: Tue, 9 Dec 2025 21:46:06 +0100
Subject: [PATCH 14/28] Exclude example test

---
 example/basic/example_test.go | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/example/basic/example_test.go b/example/basic/example_test.go
index 079bd91..0c78c63 100644
--- a/example/basic/example_test.go
+++ b/example/basic/example_test.go
@@ -1,13 +1,17 @@
+//go:build examples
+// +build examples
+
 package example
 
 import (
 	"context"
 	"fmt"
+	"testing"
+	"time"
+
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"github.com/vippsas/sqlcode/sqltest"
-	"testing"
-	"time"
 )
 
 func TestPreprocess(t *testing.T) {

From c483fb7dd063c6abc7fdaffd2e3a13065b74ac8f Mon Sep 17 00:00:00 2001
From: Kyle Simukka
Date: Tue, 9 Dec 2025 21:47:16 +0100
Subject: [PATCH 15/28] Fixed failing test.
--- preprocess_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/preprocess_test.go b/preprocess_test.go index 940b8f6..9d0d2ee 100644 --- a/preprocess_test.go +++ b/preprocess_test.go @@ -217,7 +217,7 @@ $$ language plpgsql; require.NoError(t, err) require.Len(t, result.Batches, 1) - assert.Contains(t, result.Batches[0].Lines, "code@abc123.") + assert.Contains(t, result.Batches[0].Lines, `"code@abc123".`) assert.NotContains(t, result.Batches[0].Lines, "[code@abc123].") }) From 4b803bf9e72fa20f3dfadc8e5298dadb8c695257 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 18:40:04 +0100 Subject: [PATCH 16/28] Moved Document structs to a separate file for better organization. --- sqlparser/document.go | 505 ++++++++++++++++++++++++++++++++++++++++++ sqlparser/parser.go | 492 ---------------------------------------- 2 files changed, 505 insertions(+), 492 deletions(-) create mode 100644 sqlparser/document.go diff --git a/sqlparser/document.go b/sqlparser/document.go new file mode 100644 index 0000000..c907aa8 --- /dev/null +++ b/sqlparser/document.go @@ -0,0 +1,505 @@ +package sqlparser + +import ( + "fmt" + "sort" + "strings" + + "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" +) + +type Document struct { + PragmaIncludeIf []string + Creates []Create + Declares []Declare + Errors []Error +} + +func (d *Document) addError(s *Scanner, msg string) { + d.Errors = append(d.Errors, Error{ + Pos: s.Start(), + Message: msg, + }) +} + +func (d *Document) unexpectedTokenError(s *Scanner) { + d.addError(s, "Unexpected: "+s.Token()) +} + +func (doc *Document) parseTypeExpression(s *Scanner) (t Type) { + parseArgs := func() { + // parses *after* the initial (; consumes trailing ) + for { + switch { + case s.TokenType() == NumberToken: + t.Args = append(t.Args, s.Token()) + case s.TokenType() == UnquotedIdentifierToken && s.TokenLower() == "max": + t.Args = append(t.Args, "max") + default: + 
doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + return + } + s.NextNonWhitespaceCommentToken() + switch { + case s.TokenType() == CommaToken: + s.NextNonWhitespaceCommentToken() + continue + case s.TokenType() == RightParenToken: + s.NextNonWhitespaceCommentToken() + return + default: + doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + return + } + } + } + + if s.TokenType() != UnquotedIdentifierToken { + panic("assertion failed, bug in caller") + } + t.BaseType = s.Token() + s.NextNonWhitespaceCommentToken() + if s.TokenType() == LeftParenToken { + s.NextNonWhitespaceCommentToken() + parseArgs() + } + return +} + +func (doc *Document) parseDeclare(s *Scanner) (result []Declare) { + declareStart := s.Start() + // parse what is *after* the `declare` reserved keyword +loop: + for { + if s.TokenType() != VariableIdentifierToken { + doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + return + } + + variableName := s.Token() + if !strings.HasPrefix(strings.ToLower(variableName), "@enum") && + !strings.HasPrefix(strings.ToLower(variableName), "@global") && + !strings.HasPrefix(strings.ToLower(variableName), "@const") { + doc.addError(s, "sqlcode constants needs to have names starting with @Enum, @Global or @Const: "+variableName) + } + s.NextNonWhitespaceCommentToken() + var variableType Type + switch s.TokenType() { + case EqualToken: + doc.addError(s, "sqlcode constants needs a type declared explicitly") + s.NextNonWhitespaceCommentToken() + case UnquotedIdentifierToken: + variableType = doc.parseTypeExpression(s) + } + + if s.TokenType() != EqualToken { + doc.addError(s, "sqlcode constants needs to be assigned at once using =") + doc.recoverToNextStatement(s) + } + + switch s.NextNonWhitespaceCommentToken() { + case NumberToken, NVarcharLiteralToken, VarcharLiteralToken: + result = append(result, Declare{ + Start: declareStart, + Stop: s.Stop(), + VariableName: variableName, + Datatype: variableType, + Literal: CreateUnparsed(s), + 
}) + default: + doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + return + } + + switch s.NextNonWhitespaceCommentToken() { + case CommaToken: + s.NextNonWhitespaceCommentToken() + continue + case SemicolonToken: + s.NextNonWhitespaceCommentToken() + break loop + default: + break loop + } + } + if len(result) == 0 { + doc.addError(s, "incorrect syntax; no variables successfully declared") + } + return +} + +func (doc *Document) parseBatchSeparator(s *Scanner) { + // just saw a 'go'; just make sure there's nothing bad trailing it + // (if there is, convert to errors and move on until the line is consumed + errorEmitted := false + for { + switch s.NextToken() { + case WhitespaceToken: + continue + case MalformedBatchSeparatorToken: + if !errorEmitted { + doc.addError(s, "`go` should be alone on a line without any comments") + errorEmitted = true + } + continue + default: + return + } + } +} + +func (doc *Document) parseDeclareBatch(s *Scanner) (hasMore bool) { + if s.ReservedWord() != "declare" { + panic("assertion failed, incorrect use in caller") + } + for { + tt := s.TokenType() + switch { + case tt == EOFToken: + return false + case tt == ReservedWordToken && s.ReservedWord() == "declare": + s.NextNonWhitespaceCommentToken() + d := doc.parseDeclare(s) + doc.Declares = append(doc.Declares, d...) 
+ case tt == ReservedWordToken && s.ReservedWord() != "declare": + doc.addError(s, "Only 'declare' allowed in this batch") + doc.recoverToNextStatement(s) + case tt == BatchSeparatorToken: + doc.parseBatchSeparator(s) + return true + default: + doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + } + } +} + +func (doc *Document) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { + var nodes []Unparsed + var docstring []PosString + newLineEncounteredInDocstring := false + + var createCountInBatch int + + for { + tt := s.TokenType() + switch tt { + case EOFToken: + return false + case WhitespaceToken, MultilineCommentToken: + nodes = append(nodes, CreateUnparsed(s)) + // do not reset token for a single trailing newline + t := s.Token() + if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { + newLineEncounteredInDocstring = true + } else { + docstring = nil + } + s.NextToken() + case SinglelineCommentToken: + // We build up a list of single line comments for the "docstring"; + // it is reset whenever we encounter something else + docstring = append(docstring, PosString{s.Start(), s.Token()}) + nodes = append(nodes, CreateUnparsed(s)) + newLineEncounteredInDocstring = false + s.NextToken() + case ReservedWordToken: + switch s.ReservedWord() { + case "declare": + // First declare-statement; enter a mode where we assume all contents + // of batch are declare statements + if !isFirst { + doc.addError(s, "'declare' statement only allowed in first batch") + } + // regardless of errors, go on and parse as far as we get... + return doc.parseDeclareBatch(s) + case "create": + // should be start of create procedure or create function... 
+ c := doc.parseCreate(s, createCountInBatch) + + if strings.HasSuffix(string(s.file), ".sql") { + c.Driver = &mssql.Driver{} + } + if strings.HasSuffix(string(s.file), ".pgsql") { + c.Driver = &stdlib.Driver{} + } + + // *prepend* what we saw before getting to the 'create' + createCountInBatch++ + c.Body = append(nodes, c.Body...) + c.Docstring = docstring + doc.Creates = append(doc.Creates, c) + default: + doc.addError(s, "Expected 'declare' or 'create', got: "+s.ReservedWord()) + s.NextToken() + } + case BatchSeparatorToken: + doc.parseBatchSeparator(s) + return true + default: + doc.unexpectedTokenError(s) + s.NextToken() + docstring = nil + } + } +} + +func (d *Document) recoverToNextStatementCopying(s *Scanner, target *[]Unparsed) { + // We hit an unexpected token ... as an heuristic for continuing parsing, + // skip parsing until we hit a reserved word that starts a statement + // we recognize + for { + NextTokenCopyingWhitespace(s, target) + switch s.TokenType() { + case ReservedWordToken: + switch s.ReservedWord() { + case "declare", "create", "go": + return + } + case EOFToken: + return + default: + CopyToken(s, target) + } + } +} + +func (d *Document) recoverToNextStatement(s *Scanner) { + // We hit an unexpected token ... as an heuristic for continuing parsing, + // skip parsing until we hit a reserved word that starts a statement + // we recognize + for { + s.NextNonWhitespaceCommentToken() + switch s.TokenType() { + case ReservedWordToken: + switch s.ReservedWord() { + case "declare", "create", "go": + return + } + case EOFToken: + return + } + } +} + +// parseCodeschemaName parses `[code] . something`, and returns `something` +// in quoted form (`[something]`). Also copy to `target`. Empty string on error. +// Note: To follow conventions, consume one extra token at the end even if we know +// it fill not be consumed by this function... 
+func (d *Document) parseCodeschemaName(s *Scanner, target *[]Unparsed) PosString { + CopyToken(s, target) + NextTokenCopyingWhitespace(s, target) + if s.TokenType() != DotToken { + d.addError(s, fmt.Sprintf("[code] must be followed by '.'")) + d.recoverToNextStatementCopying(s, target) + return PosString{Value: ""} + } + CopyToken(s, target) + + NextTokenCopyingWhitespace(s, target) + switch s.TokenType() { + case UnquotedIdentifierToken: + // To get something uniform for comparison, quote all names + CopyToken(s, target) + result := PosString{Pos: s.Start(), Value: "[" + s.Token() + "]"} + NextTokenCopyingWhitespace(s, target) + return result + case QuotedIdentifierToken: + CopyToken(s, target) + result := PosString{Pos: s.Start(), Value: s.Token()} + NextTokenCopyingWhitespace(s, target) + return result + default: + d.addError(s, fmt.Sprintf("[code]. must be followed an identifier")) + d.recoverToNextStatementCopying(s, target) + return PosString{Value: ""} + } +} + +// parseCreate parses anything that starts with "create". Position is +// *on* the create token. +// At this stage in sqlcode parser development we're only interested +// in procedures/functions/types as opaque blocks of SQL code where +// we only track dependencies between them and their declared name; +// so we treat them with the same code. We consume until the end of +// the batch; only one declaration allowed per batch. Everything +// parsed here will also be added to `batch`. On any error, copying +// to batch stops / becomes erratic.. 
+func (d *Document) parseCreate(s *Scanner, createCountInBatch int) (result Create) { + if s.ReservedWord() != "create" { + panic("illegal use by caller") + } + CopyToken(s, &result.Body) + + NextTokenCopyingWhitespace(s, &result.Body) + + createType := strings.ToLower(s.Token()) + if !(createType == "procedure" || createType == "function" || createType == "type") { + d.addError(s, fmt.Sprintf("sqlcode only supports creating procedures, functions or types; not `%s`", createType)) + d.recoverToNextStatementCopying(s, &result.Body) + return + } + if (createType == "procedure" || createType == "function") && createCountInBatch > 0 { + d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") + d.recoverToNextStatementCopying(s, &result.Body) + return + } + + result.CreateType = createType + CopyToken(s, &result.Body) + + NextTokenCopyingWhitespace(s, &result.Body) + + // Insist on [code]. + if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" { + d.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType)) + d.recoverToNextStatementCopying(s, &result.Body) + return + } + result.QuotedName = d.parseCodeschemaName(s, &result.Body) + if result.QuotedName.String() == "" { + return + } + + // We have matched "create [code]."; at this + // point we copy the rest until the batch ends; *but* track dependencies + // + some other details mentioned below + + //firstAs := true // See comment below on rowcount + +tailloop: + for { + tt := s.TokenType() + switch { + case tt == ReservedWordToken && s.ReservedWord() == "create": + // So, we're currently parsing 'create ...' and we see another 'create'. 
+ // We split in two cases depending on the context we are currently in + // (createType is referring to how we entered this function, *NOT* the + // `create` statement we are looking at now + switch createType { // note: this is the *outer* create type, not the one of current scanner position + case "function", "procedure": + // Within a function/procedure we can allow 'create index', 'create table' and nothing + // else. (Well, only procedures can have them, but we'll leave it to T-SQL to complain + // about that aspect, not relevant for batch / dependency parsing) + // + // What is important is a function/procedure/type isn't started on without a 'go' + // in between; so we block those 3 from appearing in the same batch + CopyToken(s, &result.Body) + NextTokenCopyingWhitespace(s, &result.Body) + tt2 := s.TokenType() + + if (tt2 == ReservedWordToken && (s.ReservedWord() == "function" || s.ReservedWord() == "procedure")) || + (tt2 == UnquotedIdentifierToken && s.TokenLower() == "type") { + d.recoverToNextStatementCopying(s, &result.Body) + d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") + return + } + case "type": + // We allow more than one type creation in a batch; and 'create' can never appear + // scoped within 'create type'. 
So at a new create we are done with the previous + // one, and return it -- the caller can then re-enter this function from the top + break tailloop + default: + panic("assertion failed") + } + + case tt == EOFToken || tt == BatchSeparatorToken: + break tailloop + case tt == QuotedIdentifierToken && s.Token() == "[code]": + // Parse a dependency + dep := d.parseCodeschemaName(s, &result.Body) + found := false + for _, existing := range result.DependsOn { + if existing.Value == dep.Value { + found = true + break + } + } + if !found { + result.DependsOn = append(result.DependsOn, dep) + } + case tt == ReservedWordToken && s.Token() == "as": + CopyToken(s, &result.Body) + NextTokenCopyingWhitespace(s, &result.Body) + /* + TODO: Fix and re-enable + This code add RoutineName for convenience. So: + + create procedure [code@5420c0269aaf].Test as + begin + select 1 + end + go + + becomes: + + create procedure [code@5420c0269aaf].Test as + declare @RoutineName nvarchar(128) + set @RoutineName = 'Test' + begin + select 1 + end + go + + However, for some very strange reason, @@rowcount is 1 with the first version, + and it is 2 with the second version. + if firstAs { + // Add the `RoutineName` token as a convenience, so that we can refer to the procedure/function name + // from inside the procedure (for example, when logging) + if result.CreateType == "procedure" { + procNameToken := Unparsed{ + Type: OtherToken, + RawValue: fmt.Sprintf(templateRoutineName, strings.Trim(result.QuotedName.Value, "[]")), + } + result.Body = append(result.Body, procNameToken) + } + firstAs = false + } + */ + + default: + CopyToken(s, &result.Body) + NextTokenCopyingWhitespace(s, &result.Body) + } + } + + sort.Slice(result.DependsOn, func(i, j int) bool { + return result.DependsOn[i].Value < result.DependsOn[j].Value + }) + return +} + +// NextTokenCopyingWhitespace is like s.NextToken(), but if whitespace is encountered +// it is simply copied into `target`. 
Upon return, the scanner is located at a non-whitespace +// token, and target is either unmodified or filled with some whitespace nodes. +func NextTokenCopyingWhitespace(s *Scanner, target *[]Unparsed) { + for { + tt := s.NextToken() + switch tt { + case EOFToken, BatchSeparatorToken: + // do not copy + return + case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: + // copy, and loop around + CopyToken(s, target) + continue + default: + return + } + } + +} + +func CreateUnparsed(s *Scanner) Unparsed { + return Unparsed{ + Type: s.TokenType(), + Start: s.Start(), + Stop: s.Stop(), + RawValue: s.Token(), + } +} diff --git a/sqlparser/parser.go b/sqlparser/parser.go index b88e4ea..f2b51d8 100644 --- a/sqlparser/parser.go +++ b/sqlparser/parser.go @@ -12,11 +12,7 @@ import ( "path/filepath" "regexp" "slices" - "sort" "strings" - - mssql "github.com/denisenkom/go-mssqldb" - "github.com/jackc/pgx/v5/stdlib" ) var templateRoutineName string = "\ndeclare @RoutineName nvarchar(128)\nset @RoutineName = '%s'\n" @@ -27,27 +23,6 @@ func CopyToken(s *Scanner, target *[]Unparsed) { *target = append(*target, CreateUnparsed(s)) } -// NextTokenCopyingWhitespace is like s.NextToken(), but if whitespace is encountered -// it is simply copied into `target`. Upon return, the scanner is located at a non-whitespace -// token, and target is either unmodified or filled with some whitespace nodes. 
-func NextTokenCopyingWhitespace(s *Scanner, target *[]Unparsed) { - for { - tt := s.NextToken() - switch tt { - case EOFToken, BatchSeparatorToken: - // do not copy - return - case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: - // copy, and loop around - CopyToken(s, target) - continue - default: - return - } - } - -} - // AdvanceAndCopy is like NextToken; advance to next token that is not whitespace and return // Note: The 'go' and EOF tokens are *not* copied func AdvanceAndCopy(s *Scanner, target *[]Unparsed) { @@ -69,473 +44,6 @@ func AdvanceAndCopy(s *Scanner, target *[]Unparsed) { } } -func CreateUnparsed(s *Scanner) Unparsed { - return Unparsed{ - Type: s.TokenType(), - Start: s.Start(), - Stop: s.Stop(), - RawValue: s.Token(), - } -} - -func (d *Document) addError(s *Scanner, msg string) { - d.Errors = append(d.Errors, Error{ - Pos: s.Start(), - Message: msg, - }) -} - -func (d *Document) unexpectedTokenError(s *Scanner) { - d.addError(s, "Unexpected: "+s.Token()) -} - -func (doc *Document) parseTypeExpression(s *Scanner) (t Type) { - parseArgs := func() { - // parses *after* the initial (; consumes trailing ) - for { - switch { - case s.TokenType() == NumberToken: - t.Args = append(t.Args, s.Token()) - case s.TokenType() == UnquotedIdentifierToken && s.TokenLower() == "max": - t.Args = append(t.Args, "max") - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - s.NextNonWhitespaceCommentToken() - switch { - case s.TokenType() == CommaToken: - s.NextNonWhitespaceCommentToken() - continue - case s.TokenType() == RightParenToken: - s.NextNonWhitespaceCommentToken() - return - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - } - } - - if s.TokenType() != UnquotedIdentifierToken { - panic("assertion failed, bug in caller") - } - t.BaseType = s.Token() - s.NextNonWhitespaceCommentToken() - if s.TokenType() == LeftParenToken { - s.NextNonWhitespaceCommentToken() - parseArgs() 
- } - return -} - -func (doc *Document) parseDeclare(s *Scanner) (result []Declare) { - declareStart := s.Start() - // parse what is *after* the `declare` reserved keyword -loop: - for { - if s.TokenType() != VariableIdentifierToken { - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - - variableName := s.Token() - if !strings.HasPrefix(strings.ToLower(variableName), "@enum") && - !strings.HasPrefix(strings.ToLower(variableName), "@global") && - !strings.HasPrefix(strings.ToLower(variableName), "@const") { - doc.addError(s, "sqlcode constants needs to have names starting with @Enum, @Global or @Const: "+variableName) - } - s.NextNonWhitespaceCommentToken() - var variableType Type - switch s.TokenType() { - case EqualToken: - doc.addError(s, "sqlcode constants needs a type declared explicitly") - s.NextNonWhitespaceCommentToken() - case UnquotedIdentifierToken: - variableType = doc.parseTypeExpression(s) - } - - if s.TokenType() != EqualToken { - doc.addError(s, "sqlcode constants needs to be assigned at once using =") - doc.recoverToNextStatement(s) - } - - switch s.NextNonWhitespaceCommentToken() { - case NumberToken, NVarcharLiteralToken, VarcharLiteralToken: - result = append(result, Declare{ - Start: declareStart, - Stop: s.Stop(), - VariableName: variableName, - Datatype: variableType, - Literal: CreateUnparsed(s), - }) - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - - switch s.NextNonWhitespaceCommentToken() { - case CommaToken: - s.NextNonWhitespaceCommentToken() - continue - case SemicolonToken: - s.NextNonWhitespaceCommentToken() - break loop - default: - break loop - } - } - if len(result) == 0 { - doc.addError(s, "incorrect syntax; no variables successfully declared") - } - return -} - -func (doc *Document) parseBatchSeparator(s *Scanner) { - // just saw a 'go'; just make sure there's nothing bad trailing it - // (if there is, convert to errors and move on until the line is consumed - 
errorEmitted := false - for { - switch s.NextToken() { - case WhitespaceToken: - continue - case MalformedBatchSeparatorToken: - if !errorEmitted { - doc.addError(s, "`go` should be alone on a line without any comments") - errorEmitted = true - } - continue - default: - return - } - } -} - -func (doc *Document) parseDeclareBatch(s *Scanner) (hasMore bool) { - if s.ReservedWord() != "declare" { - panic("assertion failed, incorrect use in caller") - } - for { - tt := s.TokenType() - switch { - case tt == EOFToken: - return false - case tt == ReservedWordToken && s.ReservedWord() == "declare": - s.NextNonWhitespaceCommentToken() - d := doc.parseDeclare(s) - doc.Declares = append(doc.Declares, d...) - case tt == ReservedWordToken && s.ReservedWord() != "declare": - doc.addError(s, "Only 'declare' allowed in this batch") - doc.recoverToNextStatement(s) - case tt == BatchSeparatorToken: - doc.parseBatchSeparator(s) - return true - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - } - } -} - -func (doc *Document) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { - var nodes []Unparsed - var docstring []PosString - newLineEncounteredInDocstring := false - - var createCountInBatch int - - for { - tt := s.TokenType() - switch tt { - case EOFToken: - return false - case WhitespaceToken, MultilineCommentToken: - nodes = append(nodes, CreateUnparsed(s)) - // do not reset token for a single trailing newline - t := s.Token() - if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { - newLineEncounteredInDocstring = true - } else { - docstring = nil - } - s.NextToken() - case SinglelineCommentToken: - // We build up a list of single line comments for the "docstring"; - // it is reset whenever we encounter something else - docstring = append(docstring, PosString{s.Start(), s.Token()}) - nodes = append(nodes, CreateUnparsed(s)) - newLineEncounteredInDocstring = false - s.NextToken() - case ReservedWordToken: - switch s.ReservedWord() { - case 
"declare": - // First declare-statement; enter a mode where we assume all contents - // of batch are declare statements - if !isFirst { - doc.addError(s, "'declare' statement only allowed in first batch") - } - // regardless of errors, go on and parse as far as we get... - return doc.parseDeclareBatch(s) - case "create": - // should be start of create procedure or create function... - c := doc.parseCreate(s, createCountInBatch) - - if strings.HasSuffix(string(s.file), ".sql") { - c.Driver = &mssql.Driver{} - } - if strings.HasSuffix(string(s.file), ".pgsql") { - c.Driver = &stdlib.Driver{} - } - - // *prepend* what we saw before getting to the 'create' - createCountInBatch++ - c.Body = append(nodes, c.Body...) - c.Docstring = docstring - doc.Creates = append(doc.Creates, c) - default: - doc.addError(s, "Expected 'declare' or 'create', got: "+s.ReservedWord()) - s.NextToken() - } - case BatchSeparatorToken: - doc.parseBatchSeparator(s) - return true - default: - doc.unexpectedTokenError(s) - s.NextToken() - docstring = nil - } - } -} - -func (d *Document) recoverToNextStatementCopying(s *Scanner, target *[]Unparsed) { - // We hit an unexpected token ... as an heuristic for continuing parsing, - // skip parsing until we hit a reserved word that starts a statement - // we recognize - for { - NextTokenCopyingWhitespace(s, target) - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "declare", "create", "go": - return - } - case EOFToken: - return - default: - CopyToken(s, target) - } - } -} - -func (d *Document) recoverToNextStatement(s *Scanner) { - // We hit an unexpected token ... 
as an heuristic for continuing parsing, - // skip parsing until we hit a reserved word that starts a statement - // we recognize - for { - s.NextNonWhitespaceCommentToken() - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "declare", "create", "go": - return - } - case EOFToken: - return - } - } -} - -// parseCodeschemaName parses `[code] . something`, and returns `something` -// in quoted form (`[something]`). Also copy to `target`. Empty string on error. -// Note: To follow conventions, consume one extra token at the end even if we know -// it fill not be consumed by this function... -func (d *Document) parseCodeschemaName(s *Scanner, target *[]Unparsed) PosString { - CopyToken(s, target) - NextTokenCopyingWhitespace(s, target) - if s.TokenType() != DotToken { - d.addError(s, fmt.Sprintf("[code] must be followed by '.'")) - d.recoverToNextStatementCopying(s, target) - return PosString{Value: ""} - } - CopyToken(s, target) - - NextTokenCopyingWhitespace(s, target) - switch s.TokenType() { - case UnquotedIdentifierToken: - // To get something uniform for comparison, quote all names - CopyToken(s, target) - result := PosString{Pos: s.Start(), Value: "[" + s.Token() + "]"} - NextTokenCopyingWhitespace(s, target) - return result - case QuotedIdentifierToken: - CopyToken(s, target) - result := PosString{Pos: s.Start(), Value: s.Token()} - NextTokenCopyingWhitespace(s, target) - return result - default: - d.addError(s, fmt.Sprintf("[code]. must be followed an identifier")) - d.recoverToNextStatementCopying(s, target) - return PosString{Value: ""} - } -} - -// parseCreate parses anything that starts with "create". Position is -// *on* the create token. -// At this stage in sqlcode parser development we're only interested -// in procedures/functions/types as opaque blocks of SQL code where -// we only track dependencies between them and their declared name; -// so we treat them with the same code. 
We consume until the end of -// the batch; only one declaration allowed per batch. Everything -// parsed here will also be added to `batch`. On any error, copying -// to batch stops / becomes erratic.. -func (d *Document) parseCreate(s *Scanner, createCountInBatch int) (result Create) { - if s.ReservedWord() != "create" { - panic("illegal use by caller") - } - CopyToken(s, &result.Body) - - NextTokenCopyingWhitespace(s, &result.Body) - - createType := strings.ToLower(s.Token()) - if !(createType == "procedure" || createType == "function" || createType == "type") { - d.addError(s, fmt.Sprintf("sqlcode only supports creating procedures, functions or types; not `%s`", createType)) - d.recoverToNextStatementCopying(s, &result.Body) - return - } - if (createType == "procedure" || createType == "function") && createCountInBatch > 0 { - d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") - d.recoverToNextStatementCopying(s, &result.Body) - return - } - - result.CreateType = createType - CopyToken(s, &result.Body) - - NextTokenCopyingWhitespace(s, &result.Body) - - // Insist on [code]. - if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" { - d.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType)) - d.recoverToNextStatementCopying(s, &result.Body) - return - } - result.QuotedName = d.parseCodeschemaName(s, &result.Body) - if result.QuotedName.String() == "" { - return - } - - // We have matched "create [code]."; at this - // point we copy the rest until the batch ends; *but* track dependencies - // + some other details mentioned below - - //firstAs := true // See comment below on rowcount - -tailloop: - for { - tt := s.TokenType() - switch { - case tt == ReservedWordToken && s.ReservedWord() == "create": - // So, we're currently parsing 'create ...' and we see another 'create'. 
- // We split in two cases depending on the context we are currently in - // (createType is referring to how we entered this function, *NOT* the - // `create` statement we are looking at now - switch createType { // note: this is the *outer* create type, not the one of current scanner position - case "function", "procedure": - // Within a function/procedure we can allow 'create index', 'create table' and nothing - // else. (Well, only procedures can have them, but we'll leave it to T-SQL to complain - // about that aspect, not relevant for batch / dependency parsing) - // - // What is important is a function/procedure/type isn't started on without a 'go' - // in between; so we block those 3 from appearing in the same batch - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) - tt2 := s.TokenType() - - if (tt2 == ReservedWordToken && (s.ReservedWord() == "function" || s.ReservedWord() == "procedure")) || - (tt2 == UnquotedIdentifierToken && s.TokenLower() == "type") { - d.recoverToNextStatementCopying(s, &result.Body) - d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") - return - } - case "type": - // We allow more than one type creation in a batch; and 'create' can never appear - // scoped within 'create type'. 
So at a new create we are done with the previous - // one, and return it -- the caller can then re-enter this function from the top - break tailloop - default: - panic("assertion failed") - } - - case tt == EOFToken || tt == BatchSeparatorToken: - break tailloop - case tt == QuotedIdentifierToken && s.Token() == "[code]": - // Parse a dependency - dep := d.parseCodeschemaName(s, &result.Body) - found := false - for _, existing := range result.DependsOn { - if existing.Value == dep.Value { - found = true - break - } - } - if !found { - result.DependsOn = append(result.DependsOn, dep) - } - case tt == ReservedWordToken && s.Token() == "as": - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) - /* - TODO: Fix and re-enable - This code add RoutineName for convenience. So: - - create procedure [code@5420c0269aaf].Test as - begin - select 1 - end - go - - becomes: - - create procedure [code@5420c0269aaf].Test as - declare @RoutineName nvarchar(128) - set @RoutineName = 'Test' - begin - select 1 - end - go - - However, for some very strange reason, @@rowcount is 1 with the first version, - and it is 2 with the second version. - if firstAs { - // Add the `RoutineName` token as a convenience, so that we can refer to the procedure/function name - // from inside the procedure (for example, when logging) - if result.CreateType == "procedure" { - procNameToken := Unparsed{ - Type: OtherToken, - RawValue: fmt.Sprintf(templateRoutineName, strings.Trim(result.QuotedName.Value, "[]")), - } - result.Body = append(result.Body, procNameToken) - } - firstAs = false - } - */ - - default: - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) - } - } - - sort.Slice(result.DependsOn, func(i, j int) bool { - return result.DependsOn[i].Value < result.DependsOn[j].Value - }) - return -} - func Parse(s *Scanner, result *Document) { // Top-level parse; this focuses on splitting into "batches" separated // by 'go'. 
From f202e182eb53f9b40dedd9f3ac3e9be91e52c42b Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 18:41:23 +0100 Subject: [PATCH 17/28] Updated go-mssql depedency to use microsoft fork. DropAndUpload now support postgresql. --- cli/cmd/build.go | 2 +- cli/cmd/config.go | 11 +++--- dbops.go | 23 +++++++++++-- deployable.go | 2 +- docker-compose.mssql.yml | 2 +- dockerfile.test | 2 -- go.mod | 16 +++++---- go.sum | 73 ++++++++++++++++++++++------------------ mssql_error.go | 2 +- preprocess_test.go | 2 +- sqlparser/parser_test.go | 2 +- sqltest/fixture.go | 36 ++++++++++++-------- sqltest/sqlcode_test.go | 33 +++++++++++++++--- 13 files changed, 133 insertions(+), 73 deletions(-) diff --git a/cli/cmd/build.go b/cli/cmd/build.go index c0d7db3..9fd9d9a 100644 --- a/cli/cmd/build.go +++ b/cli/cmd/build.go @@ -4,7 +4,7 @@ import ( "errors" "fmt" - mssql "github.com/denisenkom/go-mssqldb" + mssql "github.com/microsoft/go-mssqldb" "github.com/spf13/cobra" "github.com/vippsas/sqlcode" ) diff --git a/cli/cmd/config.go b/cli/cmd/config.go index 6bebbf1..6968802 100644 --- a/cli/cmd/config.go +++ b/cli/cmd/config.go @@ -5,16 +5,17 @@ import ( "database/sql" "errors" "fmt" - mssql "github.com/denisenkom/go-mssqldb" - "github.com/denisenkom/go-mssqldb/azuread" - "golang.org/x/net/proxy" "io/ioutil" "os" "path" "strings" - _ "github.com/denisenkom/go-mssqldb/azuread" - "github.com/denisenkom/go-mssqldb/msdsn" + mssql "github.com/microsoft/go-mssqldb" + "github.com/microsoft/go-mssqldb/azuread" + "golang.org/x/net/proxy" + + _ "github.com/microsoft/go-mssqldb/azuread" + "github.com/microsoft/go-mssqldb/msdsn" "github.com/sirupsen/logrus" "gopkg.in/yaml.v3" ) diff --git a/dbops.go b/dbops.go index 0293a2a..d63278f 100644 --- a/dbops.go +++ b/dbops.go @@ -4,8 +4,9 @@ import ( "context" "database/sql" - mssql "github.com/denisenkom/go-mssqldb" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" ) func Exists(ctx 
context.Context, dbc DB, schemasuffix string) (bool, error) { @@ -33,8 +34,24 @@ func Drop(ctx context.Context, dbc DB, schemasuffix string) error { if err != nil { return err } - _, err = tx.ExecContext(ctx, `sqlcode.DropCodeSchema`, - sql.Named("schemasuffix", schemasuffix)) + + var qs string + var arg = []interface{}{} + driver := dbc.Driver() + + if _, ok := driver.(*mssql.Driver); ok { + qs = `sqlcode.DropCodeSchema` + arg = []interface{}{sql.Named("schemasuffix", schemasuffix)} + } + + if _, ok := dbc.Driver().(*stdlib.Driver); ok { + qs = `call sqlcode.dropcodeschema(@schemasuffix)` + arg = []interface{}{ + pgx.NamedArgs{"schemasuffix": schemasuffix}, + } + } + + _, err = tx.ExecContext(ctx, qs, arg...) if err != nil { _ = tx.Rollback() return err diff --git a/deployable.go b/deployable.go index ec784d7..03cc0a8 100644 --- a/deployable.go +++ b/deployable.go @@ -10,10 +10,10 @@ import ( "strings" "time" - mssql "github.com/denisenkom/go-mssqldb" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/stdlib" pgxstdlib "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" "github.com/vippsas/sqlcode/sqlparser" ) diff --git a/docker-compose.mssql.yml b/docker-compose.mssql.yml index 84618fc..f2e4471 100644 --- a/docker-compose.mssql.yml +++ b/docker-compose.mssql.yml @@ -17,7 +17,7 @@ services: - mssql environment: SQLSERVER_DSN: sqlserver://mssql:1433?database=master&user id=sa&password=VippsPw1 - SQLSERVER_DRIVER: sqlserver + GODEBUG: "x509negativeserial=1" depends_on: mssql: condition: service_healthy diff --git a/dockerfile.test b/dockerfile.test index 286268c..dbd2061 100644 --- a/dockerfile.test +++ b/dockerfile.test @@ -1,7 +1,5 @@ FROM golang:1.25 AS builder WORKDIR /sqlcode -ENV GODEBUG="x509negativeserial=1" COPY . . 
RUN go mod tidy -# Skip the example folder because it has examples of what-not-to-do CMD ["go", "test", "-v", "./..."] \ No newline at end of file diff --git a/go.mod b/go.mod index b482681..a1ad8a1 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,9 @@ go 1.24.3 require ( github.com/alecthomas/repr v0.5.2 - github.com/denisenkom/go-mssqldb v0.12.3 github.com/gofrs/uuid v4.4.0+incompatible + github.com/jackc/pgx/v5 v5.7.6 + github.com/microsoft/go-mssqldb v1.9.5 github.com/sirupsen/logrus v1.9.3 github.com/smasher164/xid v0.1.2 github.com/spf13/cobra v1.10.1 @@ -15,20 +16,23 @@ require ( ) require ( - github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0 // indirect - github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0 // indirect - github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/golang-jwt/jwt/v5 v5.2.2 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect - github.com/jackc/pgx/v5 v5.7.6 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect - github.com/lib/pq v1.10.9 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/shopspring/decimal v1.4.0 // indirect github.com/spf13/pflag v1.0.9 // indirect golang.org/x/crypto v0.43.0 // indirect golang.org/x/sync v0.17.0 
// indirect diff --git a/go.sum b/go.sum index 17eb045..6fd729c 100644 --- a/go.sum +++ b/go.sum @@ -1,25 +1,39 @@ -github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0 h1:lhSJz9RMbJcTgxifR1hUNJnn6CNYtbgEDtQV22/9RBA= -github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0/go.mod h1:h6H6c8enJmmocHUbLiiGY6sx7f9i+X3m1CHdd5c6Rdw= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0 h1:OYa9vmRX2XC5GXRAzeggG12sF/z5D9Ahtdm9EJ00WN4= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0/go.mod h1:HcM1YX14R7CJcghJGOYCgdezslRSVzqwLf/q+4Y2r/0= -github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0 h1:v9p9TfTbf7AwNb5NYQt7hI41IfPoLFiFkLtb+bmGjT0= -github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0/go.mod h1:yqy467j36fJxcRV2TzfVZ1pCb5vxm4BtZPUdYWe/Xo8= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 h1:FPKJS1T+clwv+OLGt13a8UjqeRuh0O4SJ3lUriThc+4= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1/go.mod h1:j2chePtV91HrC22tGoRX3sGY42uF13WzmmV80/OdVAA= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1 h1:Wgf5rZba3YZqeTNJPtvqZoBu1sBN/L4sry+u2U3Y75w= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1/go.mod h1:xxCBG/f/4Vbmh2XQJBsOmNdxWUY5j/s27jujKPbQf14= +github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1 h1:bFWuoEKg+gImo7pvkiQEFAc8ocibADgXeiLAxWhWmkI= 
+github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1/go.mod h1:Vih/3yc6yac2JzU4hzpaDupBJP0Flaia9rXXrU8xyww= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs= github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/denisenkom/go-mssqldb v0.12.3 h1:pBSGx9Tq67pBOTLmxNuirNTeB8Vjmf886Kx+8Y+8shw= -github.com/denisenkom/go-mssqldb v0.12.3/go.mod h1:k0mtMFOnU+AihqFxPMiF05rtiDrorD1Vrm1KEz5hxDo= -github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA= github.com/gofrs/uuid v4.4.0+incompatible/go.mod 
h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= -github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -30,15 +44,27 @@ github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= -github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA= +github.com/keybase/go-keychain v0.0.1 
h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU= +github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/microsoft/go-mssqldb v1.9.5 h1:orwya0X/5bsL1o+KasupTkk2eNTNFkTQG0BEe/HxCn0= +github.com/microsoft/go-mssqldb v1.9.5/go.mod h1:VCP2a0KEZZtGLRHd1PsLavLFYy/3xX2yJUPycv3Sr2Q= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.8.0 h1:q3nRvjrlge/6UD7eTu/DSg2uYiU2mCL0G/uzBWqhicI= +github.com/redis/go-redis/v9 v9.8.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod 
h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smasher164/xid v0.1.2 h1:erplXSdBRIIw+MrwjJ/m8sLN2XY16UGzpTA0E2Ru6HA= @@ -52,40 +78,23 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20210610132358-84b48f89b13b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys 
v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 
h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/mssql_error.go b/mssql_error.go index 22d4bde..d6f531e 100644 --- a/mssql_error.go +++ b/mssql_error.go @@ -5,7 +5,7 @@ import ( "fmt" "strings" - mssql "github.com/denisenkom/go-mssqldb" + mssql "github.com/microsoft/go-mssqldb" "github.com/vippsas/sqlcode/sqlparser" ) diff --git a/preprocess_test.go b/preprocess_test.go index 9d0d2ee..e68d8bf 100644 --- a/preprocess_test.go +++ b/preprocess_test.go @@ -4,8 +4,8 @@ import ( "strings" "testing" - mssql "github.com/denisenkom/go-mssqldb" "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vippsas/sqlcode/sqlparser" diff --git a/sqlparser/parser_test.go b/sqlparser/parser_test.go index fc75460..0e36b40 100644 --- a/sqlparser/parser_test.go +++ b/sqlparser/parser_test.go @@ -7,8 +7,8 @@ import ( "testing" "testing/fstest" - mssql "github.com/denisenkom/go-mssqldb" "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/sqltest/fixture.go b/sqltest/fixture.go index e05c418..b384be2 100644 --- a/sqltest/fixture.go +++ b/sqltest/fixture.go @@ -9,23 +9,23 @@ import ( "testing" "time" - mssql "github.com/denisenkom/go-mssqldb" - "github.com/denisenkom/go-mssqldb/msdsn" "github.com/gofrs/uuid" _ "github.com/jackc/pgx/v5" _ "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" + "github.com/microsoft/go-mssqldb/msdsn" ) type SqlDriverType int const ( - SqlDriverDenisen SqlDriverType = iota + SqlDriverMssql SqlDriverType = iota SqlDriverPgx ) var sqlDrivers = map[SqlDriverType]string{ - SqlDriverDenisen: "sqlserver", - SqlDriverPgx: "pgx", + SqlDriverMssql: "sqlserver", + SqlDriverPgx: "pgx", } type StdoutLogger struct { @@ -49,7 +49,7 @@ type Fixture 
struct { } func (f *Fixture) IsSqlServer() bool { - return f.Driver == SqlDriverDenisen + return f.Driver == SqlDriverMssql } func (f *Fixture) IsPostgresql() bool { @@ -89,7 +89,7 @@ func NewFixture() *Fixture { // 32: Log transaction begin/end dsn = dsn + "&log=63" mssql.SetLogger(StdoutLogger{}) - fixture.Driver = SqlDriverDenisen + fixture.Driver = SqlDriverMssql } if strings.Contains(dsn, "postgresql") { fixture.Driver = SqlDriverPgx @@ -123,7 +123,7 @@ func NewFixture() *Fixture { panic(err) } - pdsn, _, err := msdsn.Parse(dsn) + pdsn, err := msdsn.Parse(dsn) if err != nil { panic(err) } @@ -198,13 +198,21 @@ func (f *Fixture) RunMigrationFile(filename string) { } func (f *Fixture) RunIfPostgres(t *testing.T, name string, fn func(t *testing.T)) { - if f.IsPostgresql() { - t.Run(name, fn) - } + t.Run("pgsql", func(t *testing.T) { + if f.IsPostgresql() { + t.Run(name, fn) + } else { + t.Skip() + } + }) } func (f *Fixture) RunIfMssql(t *testing.T, name string, fn func(t *testing.T)) { - if f.IsSqlServer() { - t.Run(name, fn) - } + t.Run("mssql", func(t *testing.T) { + if f.IsSqlServer() { + t.Run(name, fn) + } else { + t.Skip() + } + }) } diff --git a/sqltest/sqlcode_test.go b/sqltest/sqlcode_test.go index 910064b..b92cc7f 100644 --- a/sqltest/sqlcode_test.go +++ b/sqltest/sqlcode_test.go @@ -27,7 +27,7 @@ func Test_Patch(t *testing.T) { require.NoError(t, SQL.EnsureUploaded(ctx, fixture.DB)) - fixture.RunIfMssql(t, "mssql", func(t *testing.T) { + fixture.RunIfMssql(t, "returns 1 affected row", func(t *testing.T) { patched := SQL.CodePatch(fixture.DB, `[code].Test`) res, err := fixture.DB.ExecContext(ctx, patched) require.NoError(t, err) @@ -37,7 +37,8 @@ func Test_Patch(t *testing.T) { assert.Equal(t, int64(1), rowsAffected) }) - fixture.RunIfPostgres(t, "pgsql", func(t *testing.T) { + // TODO: instrument a test table to perform an update operation + fixture.RunIfPostgres(t, "returns 0 affected rows", func(t *testing.T) { patched := SQL.CodePatch(fixture.DB, 
`call [code].Test()`) res, err := fixture.DB.ExecContext(ctx, patched) require.NoError(t, err) @@ -47,7 +48,6 @@ func Test_Patch(t *testing.T) { require.NoError(t, err) assert.Equal(t, int64(0), rowsAffected) }) - } func Test_EnsureUploaded(t *testing.T) { @@ -55,7 +55,7 @@ func Test_EnsureUploaded(t *testing.T) { defer f.Teardown() ctx := context.Background() - f.RunIfMssql(t, "mssql", func(t *testing.T) { + f.RunIfMssql(t, "uploads schema", func(t *testing.T) { f.RunMigrationFile("../migrations/0001.sqlcode.sql") require.NoError(t, SQL.EnsureUploaded(ctx, f.DB)) schemas, err := SQL.ListUploaded(ctx, f.DB) @@ -64,7 +64,7 @@ func Test_EnsureUploaded(t *testing.T) { }) - f.RunIfPostgres(t, "pgsql", func(t *testing.T) { + f.RunIfPostgres(t, "uploads schema", func(t *testing.T) { f.RunMigrationFile("../migrations/0001.sqlcode.pgsql") _, err := f.DB.Exec( @@ -77,3 +77,26 @@ func Test_EnsureUploaded(t *testing.T) { require.Len(t, schemas, 1) }) } + +func Test_DropAndUpload(t *testing.T) { + f := NewFixture() + defer f.Teardown() + ctx := context.Background() + + f.RunIfMssql(t, "drop and upload", func(t *testing.T) { + f.RunMigrationFile("../migrations/0001.sqlcode.sql") + require.NoError(t, SQL.EnsureUploaded(ctx, f.DB)) + require.NoError(t, SQL.DropAndUpload(ctx, f.DB)) + }) + + f.RunIfPostgres(t, "drop and upload", func(t *testing.T) { + f.RunMigrationFile("../migrations/0001.sqlcode.pgsql") + + _, err := f.DB.Exec( + fmt.Sprintf(`grant create on database "%s" to "sqlcode-definer-role"`, f.DBName)) + require.NoError(t, err) + + require.NoError(t, SQL.EnsureUploaded(ctx, f.DB)) + require.NoError(t, SQL.DropAndUpload(ctx, f.DB)) + }) +} From ad129b8164c38222abdcb1495ec33a77b1587ac5 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 18:48:02 +0100 Subject: [PATCH 18/28] Initial unit tests for T-SQL syntax parsing. 
--- sqlparser/document_test.go | 519 +++++++++++++++++++++++++++++++++++++ sqlparser/dom.go | 7 - sqlparser/scanner.go | 9 +- 3 files changed, 526 insertions(+), 9 deletions(-) create mode 100644 sqlparser/document_test.go diff --git a/sqlparser/document_test.go b/sqlparser/document_test.go new file mode 100644 index 0000000..5ddddb6 --- /dev/null +++ b/sqlparser/document_test.go @@ -0,0 +1,519 @@ +package sqlparser + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDocument_addError(t *testing.T) { + t.Run("adds error with position", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "select") + s.NextToken() + + doc.addError(s, "test error message") + + require.Len(t, doc.Errors, 1) + assert.Equal(t, "test error message", doc.Errors[0].Message) + assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, doc.Errors[0].Pos) + }) + + t.Run("accumulates multiple errors", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "abc def") + s.NextToken() + doc.addError(s, "error 1") + s.NextToken() + doc.addError(s, "error 2") + + require.Len(t, doc.Errors, 2) + assert.Equal(t, "error 1", doc.Errors[0].Message) + assert.Equal(t, "error 2", doc.Errors[1].Message) + }) + + t.Run("creates error with token text", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "unexpected_token") + s.NextToken() + + doc.unexpectedTokenError(s) + + require.Len(t, doc.Errors, 1) + assert.Equal(t, "Unexpected: unexpected_token", doc.Errors[0].Message) + }) +} + +func TestDocument_parseTypeExpression(t *testing.T) { + t.Run("parses simple type without args", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "int") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "int", typ.BaseType) + assert.Empty(t, typ.Args) + }) + + t.Run("parses type with single arg", func(t *testing.T) { + doc := &Document{} + s := 
NewScanner("test.sql", "varchar(50)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "varchar", typ.BaseType) + assert.Equal(t, []string{"50"}, typ.Args) + }) + + t.Run("parses type with multiple args", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "decimal(10, 2)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "decimal", typ.BaseType) + assert.Equal(t, []string{"10", "2"}, typ.Args) + }) + + t.Run("parses type with max", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "nvarchar(max)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "nvarchar", typ.BaseType) + assert.Equal(t, []string{"max"}, typ.Args) + }) + + t.Run("handles invalid arg", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "varchar(invalid)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "varchar", typ.BaseType) + assert.NotEmpty(t, doc.Errors) + }) + + t.Run("panics if not on identifier", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "123") + s.NextToken() + + assert.Panics(t, func() { + doc.parseTypeExpression(s) + }) + }) +} + +func TestDocument_parseDeclare(t *testing.T) { + t.Run("parses single enum declaration", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "@EnumStatus int = 42") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "@EnumStatus", declares[0].VariableName) + assert.Equal(t, "int", declares[0].Datatype.BaseType) + assert.Equal(t, "42", declares[0].Literal.RawValue) + }) + + t.Run("parses multiple declarations with comma", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "@EnumA int = 1, @EnumB int = 2;") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 2) + assert.Equal(t, "@EnumA", declares[0].VariableName) + assert.Equal(t, "@EnumB", 
declares[1].VariableName) + }) + + t.Run("parses string literal", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "@EnumName nvarchar(50) = N'test'") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "N'test'", declares[0].Literal.RawValue) + }) + + t.Run("errors on invalid variable name", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "@InvalidName int = 1") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.NotEmpty(t, doc.Errors) + assert.Contains(t, doc.Errors[0].Message, "@InvalidName") + }) + + t.Run("errors on missing type", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "@EnumTest = 42") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.NotEmpty(t, doc.Errors) + assert.Contains(t, doc.Errors[0].Message, "type declared explicitly") + }) + + t.Run("errors on missing assignment", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "@EnumTest int") + s.NextToken() + + doc.parseDeclare(s) + + assert.NotEmpty(t, doc.Errors) + assert.Contains(t, doc.Errors[0].Message, "needs to be assigned") + }) + + t.Run("accepts @Global prefix", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "@GlobalSetting int = 100") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "@GlobalSetting", declares[0].VariableName) + assert.Empty(t, doc.Errors) + }) + + t.Run("accepts @Const prefix", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "@ConstValue int = 200") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "@ConstValue", declares[0].VariableName) + assert.Empty(t, doc.Errors) + }) +} + +func TestDocument_parseBatchSeparator(t *testing.T) { + t.Run("parses valid go separator", func(t *testing.T) { + doc := 
&Document{} + s := NewScanner("test.sql", "go\n") + s.NextToken() + + doc.parseBatchSeparator(s) + + assert.Empty(t, doc.Errors) + }) + + t.Run("errors on malformed separator", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "go -- comment") + s.NextToken() + + doc.parseBatchSeparator(s) + + assert.NotEmpty(t, doc.Errors) + assert.Contains(t, doc.Errors[0].Message, "should be alone") + }) +} + +func TestDocument_parseCodeschemaName(t *testing.T) { + t.Run("parses unquoted identifier", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "[code].TestProc") + s.NextToken() + var target []Unparsed + + result := doc.parseCodeschemaName(s, &target) + + assert.Equal(t, "[TestProc]", result.Value) + assert.NotEmpty(t, target) + }) + + t.Run("parses quoted identifier", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "[code].[Test Proc]") + s.NextToken() + var target []Unparsed + + result := doc.parseCodeschemaName(s, &target) + + assert.Equal(t, "[Test Proc]", result.Value) + }) + + t.Run("errors on missing dot", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "[code] TestProc") + s.NextToken() + var target []Unparsed + + result := doc.parseCodeschemaName(s, &target) + + assert.Equal(t, "", result.Value) + assert.NotEmpty(t, doc.Errors) + assert.Contains(t, doc.Errors[0].Message, "must be followed by '.'") + }) + + t.Run("errors on missing identifier", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "[code].123") + s.NextToken() + var target []Unparsed + + result := doc.parseCodeschemaName(s, &target) + + assert.Equal(t, "", result.Value) + assert.NotEmpty(t, doc.Errors) + assert.Contains(t, doc.Errors[0].Message, "must be followed an identifier") + }) +} + +func TestDocument_parseCreate(t *testing.T) { + t.Run("parses simple procedure", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create procedure [code].TestProc as begin end") + 
s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "procedure", create.CreateType) + assert.Equal(t, "[TestProc]", create.QuotedName.Value) + assert.NotEmpty(t, create.Body) + }) + + t.Run("parses function", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create function [code].TestFunc() returns int as begin return 1 end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + assert.Equal(t, "[TestFunc]", create.QuotedName.Value) + }) + + t.Run("parses type", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create type [code].TestType as table (id int)") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "type", create.CreateType) + assert.Equal(t, "[TestType]", create.QuotedName.Value) + }) + + t.Run("tracks dependencies", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create procedure [code].Proc1 as begin select * from [code].Table1 join [code].Table2 end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + require.Len(t, create.DependsOn, 2) + assert.Equal(t, "[Table1]", create.DependsOn[0].Value) + assert.Equal(t, "[Table2]", create.DependsOn[1].Value) + }) + + t.Run("deduplicates dependencies", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create procedure [code].Proc1 as begin select * from [code].Table1; select * from [code].Table1 end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + require.Len(t, create.DependsOn, 1) + assert.Equal(t, "[Table1]", create.DependsOn[0].Value) + }) + + t.Run("errors on unsupported create type", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create table [code].TestTable (id int)") + s.NextToken() + 
s.NextNonWhitespaceCommentToken() + + doc.parseCreate(s, 0) + + assert.NotEmpty(t, doc.Errors) + assert.Contains(t, doc.Errors[0].Message, "only supports creating procedures") + }) + + t.Run("errors on multiple procedures in batch", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create procedure [code].Proc2 as begin end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + doc.parseCreate(s, 1) + + assert.NotEmpty(t, doc.Errors) + assert.Contains(t, doc.Errors[0].Message, "must be alone in a batch") + }) + + t.Run("errors on missing code schema", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create procedure dbo.TestProc as begin end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + doc.parseCreate(s, 0) + + assert.NotEmpty(t, doc.Errors) + assert.Contains(t, doc.Errors[0].Message, "must be followed by [code]") + }) + + t.Run("allows create index inside procedure", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create procedure [code].Proc as begin create index IX_Test on #temp(id) end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "procedure", create.CreateType) + assert.Empty(t, doc.Errors) + }) + + t.Run("stops at batch separator", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "create procedure [code].Proc as begin end\ngo") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "[Proc]", create.QuotedName.Value) + assert.Equal(t, BatchSeparatorToken, s.TokenType()) + }) + + t.Run("panics if not on create token", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "procedure") + s.NextToken() + + assert.Panics(t, func() { + doc.parseCreate(s, 0) + }) + }) +} + +func TestNextTokenCopyingWhitespace(t *testing.T) { + t.Run("copies whitespace tokens", func(t *testing.T) { + s := NewScanner("test.sql", " \n\t token") 
+ var target []Unparsed + + NextTokenCopyingWhitespace(s, &target) + + assert.NotEmpty(t, target) + assert.Equal(t, UnquotedIdentifierToken, s.TokenType()) + }) + + t.Run("copies comments", func(t *testing.T) { + s := NewScanner("test.sql", "/* comment */ -- line\ntoken") + var target []Unparsed + + NextTokenCopyingWhitespace(s, &target) + + assert.True(t, len(target) >= 2) + assert.Equal(t, UnquotedIdentifierToken, s.TokenType()) + }) + + t.Run("stops at EOF", func(t *testing.T) { + s := NewScanner("test.sql", " ") + var target []Unparsed + + NextTokenCopyingWhitespace(s, &target) + + assert.Equal(t, EOFToken, s.TokenType()) + }) +} + +func TestCreateUnparsed(t *testing.T) { + t.Run("creates unparsed from scanner", func(t *testing.T) { + s := NewScanner("test.sql", "select") + s.NextToken() + + unparsed := CreateUnparsed(s) + + assert.Equal(t, ReservedWordToken, unparsed.Type) + assert.Equal(t, "select", unparsed.RawValue) + assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, unparsed.Start) + }) +} + +func TestDocument_recoverToNextStatement(t *testing.T) { + t.Run("recovers to declare", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "invalid tokens here declare @x int = 1") + s.NextToken() + + doc.recoverToNextStatement(s) + + assert.Equal(t, ReservedWordToken, s.TokenType()) + assert.Equal(t, "declare", s.ReservedWord()) + }) + + t.Run("recovers to create", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "bad stuff create procedure") + s.NextToken() + + doc.recoverToNextStatement(s) + + assert.Equal(t, "create", s.ReservedWord()) + }) + + t.Run("recovers to go", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "error error go") + s.NextToken() + + doc.recoverToNextStatement(s) + + assert.Equal(t, "go", s.ReservedWord()) + }) + + t.Run("stops at EOF", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "no keywords") + s.NextToken() + + doc.recoverToNextStatement(s) + 
+ assert.Equal(t, EOFToken, s.TokenType()) + }) +} + +func TestDocument_recoverToNextStatementCopying(t *testing.T) { + t.Run("copies tokens while recovering", func(t *testing.T) { + doc := &Document{} + s := NewScanner("test.sql", "bad token declare") + s.NextToken() + var target []Unparsed + + doc.recoverToNextStatementCopying(s, &target) + + assert.NotEmpty(t, target) + assert.Equal(t, "declare", s.ReservedWord()) + }) +} diff --git a/sqlparser/dom.go b/sqlparser/dom.go index 14209ee..22afdaa 100644 --- a/sqlparser/dom.go +++ b/sqlparser/dom.go @@ -129,13 +129,6 @@ func (e Error) WithoutPos() Error { return Error{Message: e.Message} } -type Document struct { - PragmaIncludeIf []string - Creates []Create - Declares []Declare - Errors []Error -} - func (c Create) Serialize(w io.StringWriter) error { for _, l := range c.Body { if _, err := w.WriteString(l.RawValue); err != nil { diff --git a/sqlparser/scanner.go b/sqlparser/scanner.go index 3103894..a5fb75a 100644 --- a/sqlparser/scanner.go +++ b/sqlparser/scanner.go @@ -1,11 +1,12 @@ package sqlparser import ( - "github.com/smasher164/xid" "regexp" "strings" "unicode" "unicode/utf8" + + "github.com/smasher164/xid" ) // dedicated type for reference to file, in case we need to refactor this later.. 
@@ -40,6 +41,10 @@ type Scanner struct { reservedWord string // in the event that the token is a ReservedWordToken, this contains the lower-case version } +func NewScanner(path FileRef, input string) *Scanner { + return &Scanner{input: input, file: path} +} + type TokenType int func (s *Scanner) TokenType() TokenType { @@ -316,7 +321,7 @@ func (s *Scanner) scanIdentifier() { s.curIndex = len(s.input) } -// DRY helper to handle both '' and ]] escapes +// DRY helper to handle both ” and ]] escapes func (s *Scanner) scanUntilSingleDoubleEscapes(endmarker rune, tokenType TokenType, unterminatedTokenType TokenType) TokenType { skipnext := false for i, r := range s.input[s.curIndex:] { From 352ed9fa4868b62636ec9ddc06c32cdd20781cc9 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 19:57:25 +0100 Subject: [PATCH 19/28] Refactored to use a Document interface. --- cli/cmd/constants.go | 8 +- cli/cmd/dep.go | 8 +- deployable.go | 6 +- preprocess.go | 8 +- sqlparser/create.go | 98 +++++++ sqlparser/document.go | 526 +++---------------------------------- sqlparser/document_test.go | 521 ++---------------------------------- sqlparser/dom.go | 154 +---------- sqlparser/parser.go | 33 +-- sqlparser/parser_test.go | 116 ++++---- 10 files changed, 249 insertions(+), 1229 deletions(-) create mode 100644 sqlparser/create.go diff --git a/cli/cmd/constants.go b/cli/cmd/constants.go index a71b364..90d071a 100644 --- a/cli/cmd/constants.go +++ b/cli/cmd/constants.go @@ -20,18 +20,18 @@ var ( if err != nil { return err } - if len(d.CodeBase.Creates) == 0 && len(d.CodeBase.Declares) == 0 { + if d.CodeBase.Empty() { fmt.Println("No SQL code found in given paths") } - if len(d.CodeBase.Errors) > 0 { + if d.CodeBase.HasErrors() { fmt.Println("Errors:") - for _, e := range d.CodeBase.Errors { + for _, e := range d.CodeBase.Errors() { fmt.Printf("%s:%d:%d: %s\n", e.Pos.File, e.Pos.Line, e.Pos.Line, e.Message) } return nil } fmt.Println("declare") - for i, c := range 
d.CodeBase.Declares { + for i, c := range d.CodeBase.Declares() { var prefix string if i == 0 { prefix = " " diff --git a/cli/cmd/dep.go b/cli/cmd/dep.go index c0d110c..528b5b5 100644 --- a/cli/cmd/dep.go +++ b/cli/cmd/dep.go @@ -36,16 +36,16 @@ var ( fmt.Println() err = nil } - if len(d.CodeBase.Creates) == 0 && len(d.CodeBase.Declares) == 0 { + if d.CodeBase.Empty() { fmt.Println("No SQL code found in given paths") } - if len(d.CodeBase.Errors) > 0 { + if d.CodeBase.HasErrors() { fmt.Println("Errors:") - for _, e := range d.CodeBase.Errors { + for _, e := range d.CodeBase.Errors() { fmt.Printf("%s:%d:%d: %s\n", e.Pos.File, e.Pos.Line, e.Pos.Line, e.Message) } } - for _, c := range d.CodeBase.Creates { + for _, c := range d.CodeBase.Creates() { fmt.Println(c.QuotedName.String() + ":") if len(c.DependsOn) > 0 { fmt.Println(" Uses:") diff --git a/deployable.go b/deployable.go index 03cc0a8..135fd26 100644 --- a/deployable.go +++ b/deployable.go @@ -277,7 +277,7 @@ func (d *Deployable) IsUploadedFromCache(dbc DB) bool { // TODO: StringConst. 
This requires parsing a SQL literal, a bit too complex // to code up just-in-case func (d Deployable) IntConst(s string) (int, error) { - for _, declare := range d.CodeBase.Declares { + for _, declare := range d.CodeBase.Declares() { if declare.VariableName == s { // TODO: more robust integer SQL parsing than this; only works // in most common cases @@ -311,8 +311,8 @@ type Options struct { func Include(opts Options, fsys ...fs.FS) (result Deployable, err error) { parsedFiles, doc, err := sqlparser.ParseFilesystems(fsys, opts.IncludeTags) - if len(doc.Errors) > 0 && !opts.PartialParseResults { - return Deployable{}, SQLCodeParseErrors{Errors: doc.Errors} + if doc.HasErrors() && !opts.PartialParseResults { + return Deployable{}, SQLCodeParseErrors{Errors: doc.Errors()} } result.CodeBase = doc diff --git a/preprocess.go b/preprocess.go index c4a776b..9478c3f 100644 --- a/preprocess.go +++ b/preprocess.go @@ -16,10 +16,10 @@ import ( func SchemaSuffixFromHash(doc sqlparser.Document) string { hasher := sha256.New() - for _, dec := range doc.Declares { + for _, dec := range doc.Declares() { hasher.Write([]byte(dec.String() + "\n")) } - for _, c := range doc.Creates { + for _, c := range doc.Creates() { if err := c.SerializeBytes(hasher); err != nil { panic(err) // asserting that sha256 will never return a write error... 
} @@ -149,7 +149,7 @@ func Preprocess(doc sqlparser.Document, schemasuffix string, driver driver.Drive } declares := make(map[string]string) - for _, dec := range doc.Declares { + for _, dec := range doc.Declares() { declares[dec.VariableName] = dec.Literal.RawValue } @@ -164,7 +164,7 @@ func Preprocess(doc sqlparser.Document, schemasuffix string, driver driver.Drive target = fmt.Sprintf(`"code@%s"`, schemasuffix) } - for _, create := range doc.Creates { + for _, create := range doc.Creates() { if len(create.Body) == 0 { continue } diff --git a/sqlparser/create.go b/sqlparser/create.go new file mode 100644 index 0000000..0bdbbf2 --- /dev/null +++ b/sqlparser/create.go @@ -0,0 +1,98 @@ +package sqlparser + +import ( + "database/sql/driver" + "io" + "strings" + + "gopkg.in/yaml.v3" +) + +type Create struct { + CreateType string // "procedure", "function" or "type" + QuotedName PosString // proc/func/type name, including [] + Body []Unparsed + DependsOn []PosString + Docstring []PosString // comment lines before the create statement. Note: this is also part of Body + Driver driver.Driver // the sql driver this document is intended for +} + +func (c Create) DocstringAsString() string { + var result []string + for _, line := range c.Docstring { + result = append(result, line.Value) + } + return strings.Join(result, "\n") +} + +func (c Create) DocstringYamldoc() (string, error) { + var yamldoc []string + parsing := false + for _, line := range c.Docstring { + if strings.HasPrefix(line.Value, "--!") { + parsing = true + if !strings.HasPrefix(line.Value, "--! 
") { + return "", Error{line.Pos, "YAML document in docstring; missing space after `--!`"} + } + yamldoc = append(yamldoc, line.Value[4:]) + } else if parsing { + return "", Error{line.Pos, "once embedded yaml document is started (lines prefixed with `--!`), it must continue until create statement"} + } + } + return strings.Join(yamldoc, "\n"), nil +} + +func (c Create) ParseYamlInDocstring(out any) error { + yamldoc, err := c.DocstringYamldoc() + if err != nil { + return err + } + return yaml.Unmarshal([]byte(yamldoc), out) +} + +func (c Create) Serialize(w io.StringWriter) error { + for _, l := range c.Body { + if _, err := w.WriteString(l.RawValue); err != nil { + return err + } + } + return nil +} + +func (c Create) SerializeBytes(w io.Writer) error { + for _, l := range c.Body { + if _, err := w.Write([]byte(l.RawValue)); err != nil { + return err + } + } + return nil +} + +func (c Create) String() string { + var buf strings.Builder + err := c.Serialize(&buf) + if err != nil { + panic(err) + } + return buf.String() +} + +func (c Create) WithoutPos() Create { + var body []Unparsed + for _, x := range c.Body { + body = append(body, x.WithoutPos()) + } + return Create{ + CreateType: c.CreateType, + QuotedName: c.QuotedName, + DependsOn: c.DependsOn, + Body: body, + } +} + +func (c Create) DependsOnStrings() (result []string) { + for _, x := range c.DependsOn { + result = append(result, x.Value) + } + return +} diff --git a/sqlparser/document.go b/sqlparser/document.go index c907aa8..21839ec 100644 --- a/sqlparser/document.go +++ b/sqlparser/document.go @@ -1,505 +1,45 @@ package sqlparser import ( - "fmt" - "sort" + "path/filepath" "strings" - - "github.com/jackc/pgx/v5/stdlib" - mssql "github.com/microsoft/go-mssqldb" ) -type Document struct { - PragmaIncludeIf []string - Creates []Create - Declares []Declare - Errors []Error -} - -func (d *Document) addError(s *Scanner, msg string) { - d.Errors = append(d.Errors, Error{ - Pos: s.Start(), - Message: msg, - }) -} 
- -func (d *Document) unexpectedTokenError(s *Scanner) { - d.addError(s, "Unexpected: "+s.Token()) -} - -func (doc *Document) parseTypeExpression(s *Scanner) (t Type) { - parseArgs := func() { - // parses *after* the initial (; consumes trailing ) - for { - switch { - case s.TokenType() == NumberToken: - t.Args = append(t.Args, s.Token()) - case s.TokenType() == UnquotedIdentifierToken && s.TokenLower() == "max": - t.Args = append(t.Args, "max") - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - s.NextNonWhitespaceCommentToken() - switch { - case s.TokenType() == CommaToken: - s.NextNonWhitespaceCommentToken() - continue - case s.TokenType() == RightParenToken: - s.NextNonWhitespaceCommentToken() - return - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - } - } - - if s.TokenType() != UnquotedIdentifierToken { - panic("assertion failed, bug in caller") - } - t.BaseType = s.Token() - s.NextNonWhitespaceCommentToken() - if s.TokenType() == LeftParenToken { - s.NextNonWhitespaceCommentToken() - parseArgs() - } +// Document represents a parsed SQL document, containing +// declarations, create statements, pragmas, and errors. 
+// It provides methods to access and manipulate these components +// for T-SQL and PostgreSQL +type Document interface { + Empty() bool + HasErrors() bool + + Creates() []Create + Declares() []Declare + Errors() []Error + PragmaIncludeIf() []string + Include(other Document) + Sort() + ParsePragmas(s *Scanner) + ParseBatch(s *Scanner, isFirst bool) (hasMore bool) + + WithoutPos() Document +} + +// Helper function to parse a SQL document from a string input +func ParseString(filename FileRef, input string) (result Document) { + result = NewDocumentFromExtension(filepath.Ext(strings.ToLower(string(filename)))) + Parse(&Scanner{input: input, file: filename}, result) return } -func (doc *Document) parseDeclare(s *Scanner) (result []Declare) { - declareStart := s.Start() - // parse what is *after* the `declare` reserved keyword -loop: - for { - if s.TokenType() != VariableIdentifierToken { - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - - variableName := s.Token() - if !strings.HasPrefix(strings.ToLower(variableName), "@enum") && - !strings.HasPrefix(strings.ToLower(variableName), "@global") && - !strings.HasPrefix(strings.ToLower(variableName), "@const") { - doc.addError(s, "sqlcode constants needs to have names starting with @Enum, @Global or @Const: "+variableName) - } - s.NextNonWhitespaceCommentToken() - var variableType Type - switch s.TokenType() { - case EqualToken: - doc.addError(s, "sqlcode constants needs a type declared explicitly") - s.NextNonWhitespaceCommentToken() - case UnquotedIdentifierToken: - variableType = doc.parseTypeExpression(s) - } - - if s.TokenType() != EqualToken { - doc.addError(s, "sqlcode constants needs to be assigned at once using =") - doc.recoverToNextStatement(s) - } - - switch s.NextNonWhitespaceCommentToken() { - case NumberToken, NVarcharLiteralToken, VarcharLiteralToken: - result = append(result, Declare{ - Start: declareStart, - Stop: s.Stop(), - VariableName: variableName, - Datatype: variableType, 
- Literal: CreateUnparsed(s), - }) - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - return - } - - switch s.NextNonWhitespaceCommentToken() { - case CommaToken: - s.NextNonWhitespaceCommentToken() - continue - case SemicolonToken: - s.NextNonWhitespaceCommentToken() - break loop - default: - break loop - } - } - if len(result) == 0 { - doc.addError(s, "incorrect syntax; no variables successfully declared") - } - return -} - -func (doc *Document) parseBatchSeparator(s *Scanner) { - // just saw a 'go'; just make sure there's nothing bad trailing it - // (if there is, convert to errors and move on until the line is consumed - errorEmitted := false - for { - switch s.NextToken() { - case WhitespaceToken: - continue - case MalformedBatchSeparatorToken: - if !errorEmitted { - doc.addError(s, "`go` should be alone on a line without any comments") - errorEmitted = true - } - continue - default: - return - } - } -} - -func (doc *Document) parseDeclareBatch(s *Scanner) (hasMore bool) { - if s.ReservedWord() != "declare" { - panic("assertion failed, incorrect use in caller") - } - for { - tt := s.TokenType() - switch { - case tt == EOFToken: - return false - case tt == ReservedWordToken && s.ReservedWord() == "declare": - s.NextNonWhitespaceCommentToken() - d := doc.parseDeclare(s) - doc.Declares = append(doc.Declares, d...) 
- case tt == ReservedWordToken && s.ReservedWord() != "declare": - doc.addError(s, "Only 'declare' allowed in this batch") - doc.recoverToNextStatement(s) - case tt == BatchSeparatorToken: - doc.parseBatchSeparator(s) - return true - default: - doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) - } - } -} - -func (doc *Document) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { - var nodes []Unparsed - var docstring []PosString - newLineEncounteredInDocstring := false - - var createCountInBatch int - - for { - tt := s.TokenType() - switch tt { - case EOFToken: - return false - case WhitespaceToken, MultilineCommentToken: - nodes = append(nodes, CreateUnparsed(s)) - // do not reset token for a single trailing newline - t := s.Token() - if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { - newLineEncounteredInDocstring = true - } else { - docstring = nil - } - s.NextToken() - case SinglelineCommentToken: - // We build up a list of single line comments for the "docstring"; - // it is reset whenever we encounter something else - docstring = append(docstring, PosString{s.Start(), s.Token()}) - nodes = append(nodes, CreateUnparsed(s)) - newLineEncounteredInDocstring = false - s.NextToken() - case ReservedWordToken: - switch s.ReservedWord() { - case "declare": - // First declare-statement; enter a mode where we assume all contents - // of batch are declare statements - if !isFirst { - doc.addError(s, "'declare' statement only allowed in first batch") - } - // regardless of errors, go on and parse as far as we get... - return doc.parseDeclareBatch(s) - case "create": - // should be start of create procedure or create function... 
- c := doc.parseCreate(s, createCountInBatch) - - if strings.HasSuffix(string(s.file), ".sql") { - c.Driver = &mssql.Driver{} - } - if strings.HasSuffix(string(s.file), ".pgsql") { - c.Driver = &stdlib.Driver{} - } - - // *prepend* what we saw before getting to the 'create' - createCountInBatch++ - c.Body = append(nodes, c.Body...) - c.Docstring = docstring - doc.Creates = append(doc.Creates, c) - default: - doc.addError(s, "Expected 'declare' or 'create', got: "+s.ReservedWord()) - s.NextToken() - } - case BatchSeparatorToken: - doc.parseBatchSeparator(s) - return true - default: - doc.unexpectedTokenError(s) - s.NextToken() - docstring = nil - } - } -} - -func (d *Document) recoverToNextStatementCopying(s *Scanner, target *[]Unparsed) { - // We hit an unexpected token ... as an heuristic for continuing parsing, - // skip parsing until we hit a reserved word that starts a statement - // we recognize - for { - NextTokenCopyingWhitespace(s, target) - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "declare", "create", "go": - return - } - case EOFToken: - return - default: - CopyToken(s, target) - } - } -} - -func (d *Document) recoverToNextStatement(s *Scanner) { - // We hit an unexpected token ... as an heuristic for continuing parsing, - // skip parsing until we hit a reserved word that starts a statement - // we recognize - for { - s.NextNonWhitespaceCommentToken() - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "declare", "create", "go": - return - } - case EOFToken: - return - } - } -} - -// parseCodeschemaName parses `[code] . something`, and returns `something` -// in quoted form (`[something]`). Also copy to `target`. Empty string on error. -// Note: To follow conventions, consume one extra token at the end even if we know -// it fill not be consumed by this function... 
-func (d *Document) parseCodeschemaName(s *Scanner, target *[]Unparsed) PosString { - CopyToken(s, target) - NextTokenCopyingWhitespace(s, target) - if s.TokenType() != DotToken { - d.addError(s, fmt.Sprintf("[code] must be followed by '.'")) - d.recoverToNextStatementCopying(s, target) - return PosString{Value: ""} - } - CopyToken(s, target) - - NextTokenCopyingWhitespace(s, target) - switch s.TokenType() { - case UnquotedIdentifierToken: - // To get something uniform for comparison, quote all names - CopyToken(s, target) - result := PosString{Pos: s.Start(), Value: "[" + s.Token() + "]"} - NextTokenCopyingWhitespace(s, target) - return result - case QuotedIdentifierToken: - CopyToken(s, target) - result := PosString{Pos: s.Start(), Value: s.Token()} - NextTokenCopyingWhitespace(s, target) - return result +// Based on the input file extension, create the appropriate Document type +func NewDocumentFromExtension(extension string) Document { + switch extension { + case ".sql": + return &TSqlDocument{} + case ".pgsql": + return &PGSqlDocument{} default: - d.addError(s, fmt.Sprintf("[code]. must be followed an identifier")) - d.recoverToNextStatementCopying(s, target) - return PosString{Value: ""} - } -} - -// parseCreate parses anything that starts with "create". Position is -// *on* the create token. -// At this stage in sqlcode parser development we're only interested -// in procedures/functions/types as opaque blocks of SQL code where -// we only track dependencies between them and their declared name; -// so we treat them with the same code. We consume until the end of -// the batch; only one declaration allowed per batch. Everything -// parsed here will also be added to `batch`. On any error, copying -// to batch stops / becomes erratic.. 
-func (d *Document) parseCreate(s *Scanner, createCountInBatch int) (result Create) { - if s.ReservedWord() != "create" { - panic("illegal use by caller") - } - CopyToken(s, &result.Body) - - NextTokenCopyingWhitespace(s, &result.Body) - - createType := strings.ToLower(s.Token()) - if !(createType == "procedure" || createType == "function" || createType == "type") { - d.addError(s, fmt.Sprintf("sqlcode only supports creating procedures, functions or types; not `%s`", createType)) - d.recoverToNextStatementCopying(s, &result.Body) - return - } - if (createType == "procedure" || createType == "function") && createCountInBatch > 0 { - d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") - d.recoverToNextStatementCopying(s, &result.Body) - return - } - - result.CreateType = createType - CopyToken(s, &result.Body) - - NextTokenCopyingWhitespace(s, &result.Body) - - // Insist on [code]. - if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" { - d.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType)) - d.recoverToNextStatementCopying(s, &result.Body) - return - } - result.QuotedName = d.parseCodeschemaName(s, &result.Body) - if result.QuotedName.String() == "" { - return - } - - // We have matched "create [code]."; at this - // point we copy the rest until the batch ends; *but* track dependencies - // + some other details mentioned below - - //firstAs := true // See comment below on rowcount - -tailloop: - for { - tt := s.TokenType() - switch { - case tt == ReservedWordToken && s.ReservedWord() == "create": - // So, we're currently parsing 'create ...' and we see another 'create'. 
- // We split in two cases depending on the context we are currently in - // (createType is referring to how we entered this function, *NOT* the - // `create` statement we are looking at now - switch createType { // note: this is the *outer* create type, not the one of current scanner position - case "function", "procedure": - // Within a function/procedure we can allow 'create index', 'create table' and nothing - // else. (Well, only procedures can have them, but we'll leave it to T-SQL to complain - // about that aspect, not relevant for batch / dependency parsing) - // - // What is important is a function/procedure/type isn't started on without a 'go' - // in between; so we block those 3 from appearing in the same batch - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) - tt2 := s.TokenType() - - if (tt2 == ReservedWordToken && (s.ReservedWord() == "function" || s.ReservedWord() == "procedure")) || - (tt2 == UnquotedIdentifierToken && s.TokenLower() == "type") { - d.recoverToNextStatementCopying(s, &result.Body) - d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") - return - } - case "type": - // We allow more than one type creation in a batch; and 'create' can never appear - // scoped within 'create type'. 
So at a new create we are done with the previous - // one, and return it -- the caller can then re-enter this function from the top - break tailloop - default: - panic("assertion failed") - } - - case tt == EOFToken || tt == BatchSeparatorToken: - break tailloop - case tt == QuotedIdentifierToken && s.Token() == "[code]": - // Parse a dependency - dep := d.parseCodeschemaName(s, &result.Body) - found := false - for _, existing := range result.DependsOn { - if existing.Value == dep.Value { - found = true - break - } - } - if !found { - result.DependsOn = append(result.DependsOn, dep) - } - case tt == ReservedWordToken && s.Token() == "as": - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) - /* - TODO: Fix and re-enable - This code add RoutineName for convenience. So: - - create procedure [code@5420c0269aaf].Test as - begin - select 1 - end - go - - becomes: - - create procedure [code@5420c0269aaf].Test as - declare @RoutineName nvarchar(128) - set @RoutineName = 'Test' - begin - select 1 - end - go - - However, for some very strange reason, @@rowcount is 1 with the first version, - and it is 2 with the second version. - if firstAs { - // Add the `RoutineName` token as a convenience, so that we can refer to the procedure/function name - // from inside the procedure (for example, when logging) - if result.CreateType == "procedure" { - procNameToken := Unparsed{ - Type: OtherToken, - RawValue: fmt.Sprintf(templateRoutineName, strings.Trim(result.QuotedName.Value, "[]")), - } - result.Body = append(result.Body, procNameToken) - } - firstAs = false - } - */ - - default: - CopyToken(s, &result.Body) - NextTokenCopyingWhitespace(s, &result.Body) - } - } - - sort.Slice(result.DependsOn, func(i, j int) bool { - return result.DependsOn[i].Value < result.DependsOn[j].Value - }) - return -} - -// NextTokenCopyingWhitespace is like s.NextToken(), but if whitespace is encountered -// it is simply copied into `target`. 
Upon return, the scanner is located at a non-whitespace -// token, and target is either unmodified or filled with some whitespace nodes. -func NextTokenCopyingWhitespace(s *Scanner, target *[]Unparsed) { - for { - tt := s.NextToken() - switch tt { - case EOFToken, BatchSeparatorToken: - // do not copy - return - case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: - // copy, and loop around - CopyToken(s, target) - continue - default: - return - } - } - -} - -func CreateUnparsed(s *Scanner) Unparsed { - return Unparsed{ - Type: s.TokenType(), - Start: s.Start(), - Stop: s.Stop(), - RawValue: s.Token(), + panic("unhandled document type: " + extension) } } diff --git a/sqlparser/document_test.go b/sqlparser/document_test.go index 5ddddb6..8e497e6 100644 --- a/sqlparser/document_test.go +++ b/sqlparser/document_test.go @@ -7,513 +7,50 @@ import ( "github.com/stretchr/testify/require" ) -func TestDocument_addError(t *testing.T) { - t.Run("adds error with position", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "select") - s.NextToken() +func TestNewDocumentFromExtension(t *testing.T) { + t.Run("returns TSqlDocument for .sql extension", func(t *testing.T) { + doc := NewDocumentFromExtension(".sql") - doc.addError(s, "test error message") - - require.Len(t, doc.Errors, 1) - assert.Equal(t, "test error message", doc.Errors[0].Message) - assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, doc.Errors[0].Pos) - }) - - t.Run("accumulates multiple errors", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "abc def") - s.NextToken() - doc.addError(s, "error 1") - s.NextToken() - doc.addError(s, "error 2") - - require.Len(t, doc.Errors, 2) - assert.Equal(t, "error 1", doc.Errors[0].Message) - assert.Equal(t, "error 2", doc.Errors[1].Message) + _, ok := doc.(*TSqlDocument) + assert.True(t, ok, "Expected TSqlDocument type") + assert.NotNil(t, doc) }) - t.Run("creates error with token text", func(t *testing.T) { - 
doc := &Document{} - s := NewScanner("test.sql", "unexpected_token") - s.NextToken() - - doc.unexpectedTokenError(s) + t.Run("returns PGSqlDocument for .pgsql extension", func(t *testing.T) { + doc := NewDocumentFromExtension(".pgsql") - require.Len(t, doc.Errors, 1) - assert.Equal(t, "Unexpected: unexpected_token", doc.Errors[0].Message) + _, ok := doc.(*PGSqlDocument) + assert.True(t, ok, "Expected PGSqlDocument type") + assert.NotNil(t, doc) }) -} - -func TestDocument_parseTypeExpression(t *testing.T) { - t.Run("parses simple type without args", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "int") - s.NextToken() - - typ := doc.parseTypeExpression(s) - - assert.Equal(t, "int", typ.BaseType) - assert.Empty(t, typ.Args) - }) - - t.Run("parses type with single arg", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "varchar(50)") - s.NextToken() - - typ := doc.parseTypeExpression(s) - - assert.Equal(t, "varchar", typ.BaseType) - assert.Equal(t, []string{"50"}, typ.Args) - }) - - t.Run("parses type with multiple args", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "decimal(10, 2)") - s.NextToken() - - typ := doc.parseTypeExpression(s) - - assert.Equal(t, "decimal", typ.BaseType) - assert.Equal(t, []string{"10", "2"}, typ.Args) - }) - - t.Run("parses type with max", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "nvarchar(max)") - s.NextToken() - - typ := doc.parseTypeExpression(s) - - assert.Equal(t, "nvarchar", typ.BaseType) - assert.Equal(t, []string{"max"}, typ.Args) - }) - - t.Run("handles invalid arg", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "varchar(invalid)") - s.NextToken() - - typ := doc.parseTypeExpression(s) - - assert.Equal(t, "varchar", typ.BaseType) - assert.NotEmpty(t, doc.Errors) - }) - - t.Run("panics if not on identifier", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "123") - s.NextToken() + 
t.Run("panics for unsupported extension", func(t *testing.T) { assert.Panics(t, func() { - doc.parseTypeExpression(s) - }) - }) -} - -func TestDocument_parseDeclare(t *testing.T) { - t.Run("parses single enum declaration", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "@EnumStatus int = 42") - s.NextToken() - - declares := doc.parseDeclare(s) - - require.Len(t, declares, 1) - assert.Equal(t, "@EnumStatus", declares[0].VariableName) - assert.Equal(t, "int", declares[0].Datatype.BaseType) - assert.Equal(t, "42", declares[0].Literal.RawValue) - }) - - t.Run("parses multiple declarations with comma", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "@EnumA int = 1, @EnumB int = 2;") - s.NextToken() - - declares := doc.parseDeclare(s) - - require.Len(t, declares, 2) - assert.Equal(t, "@EnumA", declares[0].VariableName) - assert.Equal(t, "@EnumB", declares[1].VariableName) - }) - - t.Run("parses string literal", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "@EnumName nvarchar(50) = N'test'") - s.NextToken() - - declares := doc.parseDeclare(s) - - require.Len(t, declares, 1) - assert.Equal(t, "N'test'", declares[0].Literal.RawValue) - }) - - t.Run("errors on invalid variable name", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "@InvalidName int = 1") - s.NextToken() - - declares := doc.parseDeclare(s) - - require.Len(t, declares, 1) - assert.NotEmpty(t, doc.Errors) - assert.Contains(t, doc.Errors[0].Message, "@InvalidName") - }) - - t.Run("errors on missing type", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "@EnumTest = 42") - s.NextToken() - - declares := doc.parseDeclare(s) - - require.Len(t, declares, 1) - assert.NotEmpty(t, doc.Errors) - assert.Contains(t, doc.Errors[0].Message, "type declared explicitly") - }) - - t.Run("errors on missing assignment", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "@EnumTest int") 
- s.NextToken() - - doc.parseDeclare(s) - - assert.NotEmpty(t, doc.Errors) - assert.Contains(t, doc.Errors[0].Message, "needs to be assigned") - }) - - t.Run("accepts @Global prefix", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "@GlobalSetting int = 100") - s.NextToken() - - declares := doc.parseDeclare(s) - - require.Len(t, declares, 1) - assert.Equal(t, "@GlobalSetting", declares[0].VariableName) - assert.Empty(t, doc.Errors) - }) - - t.Run("accepts @Const prefix", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "@ConstValue int = 200") - s.NextToken() - - declares := doc.parseDeclare(s) - - require.Len(t, declares, 1) - assert.Equal(t, "@ConstValue", declares[0].VariableName) - assert.Empty(t, doc.Errors) - }) -} - -func TestDocument_parseBatchSeparator(t *testing.T) { - t.Run("parses valid go separator", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "go\n") - s.NextToken() - - doc.parseBatchSeparator(s) - - assert.Empty(t, doc.Errors) - }) - - t.Run("errors on malformed separator", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "go -- comment") - s.NextToken() - - doc.parseBatchSeparator(s) - - assert.NotEmpty(t, doc.Errors) - assert.Contains(t, doc.Errors[0].Message, "should be alone") - }) -} - -func TestDocument_parseCodeschemaName(t *testing.T) { - t.Run("parses unquoted identifier", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "[code].TestProc") - s.NextToken() - var target []Unparsed - - result := doc.parseCodeschemaName(s, &target) - - assert.Equal(t, "[TestProc]", result.Value) - assert.NotEmpty(t, target) - }) - - t.Run("parses quoted identifier", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "[code].[Test Proc]") - s.NextToken() - var target []Unparsed - - result := doc.parseCodeschemaName(s, &target) - - assert.Equal(t, "[Test Proc]", result.Value) - }) - - t.Run("errors on missing dot", func(t 
*testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "[code] TestProc") - s.NextToken() - var target []Unparsed - - result := doc.parseCodeschemaName(s, &target) - - assert.Equal(t, "", result.Value) - assert.NotEmpty(t, doc.Errors) - assert.Contains(t, doc.Errors[0].Message, "must be followed by '.'") + NewDocumentFromExtension(".txt") + }, "Expected panic for unsupported extension") }) - t.Run("errors on missing identifier", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "[code].123") - s.NextToken() - var target []Unparsed - - result := doc.parseCodeschemaName(s, &target) - - assert.Equal(t, "", result.Value) - assert.NotEmpty(t, doc.Errors) - assert.Contains(t, doc.Errors[0].Message, "must be followed an identifier") - }) -} - -func TestDocument_parseCreate(t *testing.T) { - t.Run("parses simple procedure", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "create procedure [code].TestProc as begin end") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "procedure", create.CreateType) - assert.Equal(t, "[TestProc]", create.QuotedName.Value) - assert.NotEmpty(t, create.Body) - }) - - t.Run("parses function", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "create function [code].TestFunc() returns int as begin return 1 end") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "function", create.CreateType) - assert.Equal(t, "[TestFunc]", create.QuotedName.Value) - }) - - t.Run("parses type", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "create type [code].TestType as table (id int)") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "type", create.CreateType) - assert.Equal(t, "[TestType]", create.QuotedName.Value) - }) - - t.Run("tracks dependencies", func(t *testing.T) { - doc := 
&Document{} - s := NewScanner("test.sql", "create procedure [code].Proc1 as begin select * from [code].Table1 join [code].Table2 end") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - require.Len(t, create.DependsOn, 2) - assert.Equal(t, "[Table1]", create.DependsOn[0].Value) - assert.Equal(t, "[Table2]", create.DependsOn[1].Value) - }) - - t.Run("deduplicates dependencies", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "create procedure [code].Proc1 as begin select * from [code].Table1; select * from [code].Table1 end") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - require.Len(t, create.DependsOn, 1) - assert.Equal(t, "[Table1]", create.DependsOn[0].Value) - }) - - t.Run("errors on unsupported create type", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "create table [code].TestTable (id int)") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - doc.parseCreate(s, 0) - - assert.NotEmpty(t, doc.Errors) - assert.Contains(t, doc.Errors[0].Message, "only supports creating procedures") - }) - - t.Run("errors on multiple procedures in batch", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "create procedure [code].Proc2 as begin end") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - doc.parseCreate(s, 1) - - assert.NotEmpty(t, doc.Errors) - assert.Contains(t, doc.Errors[0].Message, "must be alone in a batch") - }) - - t.Run("errors on missing code schema", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "create procedure dbo.TestProc as begin end") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - doc.parseCreate(s, 0) - - assert.NotEmpty(t, doc.Errors) - assert.Contains(t, doc.Errors[0].Message, "must be followed by [code]") - }) - - t.Run("allows create index inside procedure", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "create procedure 
[code].Proc as begin create index IX_Test on #temp(id) end") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "procedure", create.CreateType) - assert.Empty(t, doc.Errors) - }) - - t.Run("stops at batch separator", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "create procedure [code].Proc as begin end\ngo") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "[Proc]", create.QuotedName.Value) - assert.Equal(t, BatchSeparatorToken, s.TokenType()) - }) - - t.Run("panics if not on create token", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "procedure") - s.NextToken() - + t.Run("panics for empty extension", func(t *testing.T) { assert.Panics(t, func() { - doc.parseCreate(s, 0) - }) + NewDocumentFromExtension("") + }, "Expected panic for empty extension") }) -} - -func TestNextTokenCopyingWhitespace(t *testing.T) { - t.Run("copies whitespace tokens", func(t *testing.T) { - s := NewScanner("test.sql", " \n\t token") - var target []Unparsed - - NextTokenCopyingWhitespace(s, &target) - - assert.NotEmpty(t, target) - assert.Equal(t, UnquotedIdentifierToken, s.TokenType()) - }) - - t.Run("copies comments", func(t *testing.T) { - s := NewScanner("test.sql", "/* comment */ -- line\ntoken") - var target []Unparsed - - NextTokenCopyingWhitespace(s, &target) - - assert.True(t, len(target) >= 2) - assert.Equal(t, UnquotedIdentifierToken, s.TokenType()) - }) - - t.Run("stops at EOF", func(t *testing.T) { - s := NewScanner("test.sql", " ") - var target []Unparsed - - NextTokenCopyingWhitespace(s, &target) - - assert.Equal(t, EOFToken, s.TokenType()) - }) -} - -func TestCreateUnparsed(t *testing.T) { - t.Run("creates unparsed from scanner", func(t *testing.T) { - s := NewScanner("test.sql", "select") - s.NextToken() - - unparsed := CreateUnparsed(s) - assert.Equal(t, ReservedWordToken, unparsed.Type) - 
assert.Equal(t, "select", unparsed.RawValue) - assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, unparsed.Start) - }) -} - -func TestDocument_recoverToNextStatement(t *testing.T) { - t.Run("recovers to declare", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "invalid tokens here declare @x int = 1") - s.NextToken() - - doc.recoverToNextStatement(s) - - assert.Equal(t, ReservedWordToken, s.TokenType()) - assert.Equal(t, "declare", s.ReservedWord()) - }) - - t.Run("recovers to create", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "bad stuff create procedure") - s.NextToken() - - doc.recoverToNextStatement(s) - - assert.Equal(t, "create", s.ReservedWord()) - }) - - t.Run("recovers to go", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "error error go") - s.NextToken() - - doc.recoverToNextStatement(s) - - assert.Equal(t, "go", s.ReservedWord()) + t.Run("panics for unknown SQL extension", func(t *testing.T) { + assert.Panics(t, func() { + NewDocumentFromExtension(".mysql") + }, "Expected panic for .mysql extension") }) - t.Run("stops at EOF", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "no keywords") - s.NextToken() - - doc.recoverToNextStatement(s) - - assert.Equal(t, EOFToken, s.TokenType()) + t.Run("extension matching is case insensitive", func(t *testing.T) { + assert.Panics(t, func() { + NewDocumentFromExtension(".SQL") + }, "Expected panic for uppercase .SQL") }) -} - -func TestDocument_recoverToNextStatementCopying(t *testing.T) { - t.Run("copies tokens while recovering", func(t *testing.T) { - doc := &Document{} - s := NewScanner("test.sql", "bad token declare") - s.NextToken() - var target []Unparsed - - doc.recoverToNextStatementCopying(s, &target) - assert.NotEmpty(t, target) - assert.Equal(t, "declare", s.ReservedWord()) + t.Run("returned documents implement Document interface", func(t *testing.T) { + sqlDoc := NewDocumentFromExtension(".sql") + 
pgsqlDoc := NewDocumentFromExtension(".pgsql") + require.NotEqual(t, sqlDoc, pgsqlDoc) }) } diff --git a/sqlparser/dom.go b/sqlparser/dom.go index 22afdaa..0c72587 100644 --- a/sqlparser/dom.go +++ b/sqlparser/dom.go @@ -1,12 +1,8 @@ package sqlparser import ( - "database/sql/driver" "fmt" - "io" "strings" - - "gopkg.in/yaml.v3" ) type Unparsed struct { @@ -61,48 +57,6 @@ func (p PosString) String() string { return p.Value } -type Create struct { - CreateType string // "procedure", "function" or "type" - QuotedName PosString // proc/func/type name, including [] - Body []Unparsed - DependsOn []PosString - Docstring []PosString // comment lines before the create statement. Note: this is also part of Body - Driver driver.Driver // the sql driver this document is intended for -} - -func (c Create) DocstringAsString() string { - var result []string - for _, line := range c.Docstring { - result = append(result, line.Value) - } - return strings.Join(result, "\n") -} - -func (c Create) DocstringYamldoc() (string, error) { - var yamldoc []string - parsing := false - for _, line := range c.Docstring { - if strings.HasPrefix(line.Value, "--!") { - parsing = true - if !strings.HasPrefix(line.Value, "--! 
") { - return "", Error{line.Pos, "YAML document in docstring; missing space after `--!`"} - } - yamldoc = append(yamldoc, line.Value[4:]) - } else if parsing { - return "", Error{line.Pos, "once embedded yaml document is started (lines prefixed with `--!`), it must continue until create statement"} - } - } - return strings.Join(yamldoc, "\n"), nil -} - -func (c Create) ParseYamlInDocstring(out any) error { - yamldoc, err := c.DocstringYamldoc() - if err != nil { - return err - } - return yaml.Unmarshal([]byte(yamldoc), out) -} - type Type struct { BaseType string Args []string @@ -129,107 +83,11 @@ func (e Error) WithoutPos() Error { return Error{Message: e.Message} } -func (c Create) Serialize(w io.StringWriter) error { - for _, l := range c.Body { - if _, err := w.WriteString(l.RawValue); err != nil { - return err - } - } - return nil -} - -func (c Create) SerializeBytes(w io.Writer) error { - for _, l := range c.Body { - if _, err := w.Write([]byte(l.RawValue)); err != nil { - return err - } - } - return nil -} - -func (c Create) String() string { - var buf strings.Builder - err := c.Serialize(&buf) - if err != nil { - panic(err) - } - return buf.String() -} - -func (c Create) WithoutPos() Create { - var body []Unparsed - for _, x := range c.Body { - body = append(body, x.WithoutPos()) - } - return Create{ - CreateType: c.CreateType, - QuotedName: c.QuotedName, - DependsOn: c.DependsOn, - Body: body, - } -} - -func (c Create) DependsOnStrings() (result []string) { - for _, x := range c.DependsOn { - result = append(result, x.Value) - } - return -} - -// Transform a Document to remove all Position information; this is used -// to 'unclutter' a DOM to more easily write assertions on it. 
-func (d Document) WithoutPos() Document { - var cs []Create - for _, x := range d.Creates { - cs = append(cs, x.WithoutPos()) - } - var ds []Declare - for _, x := range d.Declares { - ds = append(ds, x.WithoutPos()) - } - var es []Error - for _, x := range d.Errors { - es = append(es, x.WithoutPos()) - } - return Document{ - Creates: cs, - Declares: ds, - Errors: es, - } -} - -func (d *Document) Include(other Document) { - // Do not copy PragmaIncludeIf, since that is local to a single file. - // Its contents is also present in each Create. - d.Declares = append(d.Declares, other.Declares...) - d.Creates = append(d.Creates, other.Creates...) - d.Errors = append(d.Errors, other.Errors...) -} - -func (d *Document) parseSinglePragma(s *Scanner) { - pragma := strings.TrimSpace(strings.TrimPrefix(s.Token(), "--sqlcode:")) - if pragma == "" { - return - } - parts := strings.Split(pragma, " ") - if len(parts) != 2 { - d.addError(s, "Illegal pragma: "+s.Token()) - return - } - if parts[0] != "include-if" { - d.addError(s, "Illegal pragma: "+s.Token()) - return - } - d.PragmaIncludeIf = append(d.PragmaIncludeIf, strings.Split(parts[1], ",")...) -} - -func (d *Document) parsePragmas(s *Scanner) { - for s.TokenType() == PragmaToken { - d.parseSinglePragma(s) - s.NextNonWhitespaceToken() +func CreateUnparsed(s *Scanner) Unparsed { + return Unparsed{ + Type: s.TokenType(), + Start: s.Start(), + Stop: s.Stop(), + RawValue: s.Token(), } } - -func (d Document) Empty() bool { - return len(d.Creates) > 0 || len(d.Declares) > 0 -} diff --git a/sqlparser/parser.go b/sqlparser/parser.go index f2b51d8..1a49a08 100644 --- a/sqlparser/parser.go +++ b/sqlparser/parser.go @@ -44,7 +44,7 @@ func AdvanceAndCopy(s *Scanner, target *[]Unparsed) { } } -func Parse(s *Scanner, result *Document) { +func Parse(s *Scanner, result Document) { // Top-level parse; this focuses on splitting into "batches" separated // by 'go'. 
@@ -62,20 +62,17 @@ func Parse(s *Scanner, result *Document) { // `s` will typically never be positioned on whitespace except in // whitespace-preserving parsing + filepath.Ext(s.input) + s.NextNonWhitespaceToken() - result.parsePragmas(s) - hasMore := result.parseBatch(s, true) + result.ParsePragmas(s) + hasMore := result.ParseBatch(s, true) for hasMore { - hasMore = result.parseBatch(s, false) + hasMore = result.ParseBatch(s, false) } return } -func ParseString(filename FileRef, input string) (result Document) { - Parse(&Scanner{input: input, file: filename}, &result) - return -} - // ParseFileystems iterates through a list of filesystems and parses all supported // SQL files and returns the combination of all of them. // @@ -129,10 +126,10 @@ func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string, } hashes[hash] = pathDesc - var fdoc Document - Parse(&Scanner{input: string(buf), file: FileRef(path)}, &fdoc) + fdoc := NewDocumentFromExtension(extension) + Parse(&Scanner{input: string(buf), file: FileRef(path)}, fdoc) - if matchesIncludeTags(fdoc.PragmaIncludeIf, includeTags) { + if matchesIncludeTags(fdoc.PragmaIncludeIf(), includeTags) { filenames = append(filenames, pathDesc) result.Include(fdoc) } @@ -144,17 +141,7 @@ func ParseFilesystems(fslst []fs.FS, includeTags []string) (filenames []string, } } - // Do the topological sort; and include any error with it as part - // of `result`, *not* return it as err - sortedCreates, errpos, sortErr := TopologicalSort(result.Creates) - if sortErr != nil { - result.Errors = append(result.Errors, Error{ - Pos: errpos, - Message: sortErr.Error(), - }) - } else { - result.Creates = sortedCreates - } + result.Sort() return } diff --git a/sqlparser/parser_test.go b/sqlparser/parser_test.go index 0e36b40..39f1e58 100644 --- a/sqlparser/parser_test.go +++ b/sqlparser/parser_test.go @@ -24,8 +24,8 @@ end; $$; `) - require.Len(t, doc.Creates, 1) - require.Equal(t, &stdlib.Driver{}, doc.Creates[0].Driver) + 
require.Len(t, doc.Creates(), 1) + require.Equal(t, &stdlib.Driver{}, doc.Creates()[0].Driver) } func TestParserSmokeTest(t *testing.T) { @@ -60,8 +60,8 @@ end; `) docNoPos := doc.WithoutPos() - require.Equal(t, 1, len(doc.Creates)) - c := doc.Creates[0] + require.Equal(t, 1, len(doc.Creates())) + c := doc.Creates()[0] require.Equal(t, &mssql.Driver{}, c.Driver) assert.Equal(t, "[TestFunc]", c.QuotedName.Value) @@ -82,7 +82,7 @@ end; { Message: "'declare' statement only allowed in first batch", }, - }, docNoPos.Errors) + }, docNoPos.Errors()) assert.Equal(t, []Declare{ @@ -130,7 +130,7 @@ end; }, }, }, - docNoPos.Declares, + docNoPos.Declares(), ) // repr.Println(doc) } @@ -150,21 +150,21 @@ create function [code].Two(); { Message: "a procedure/function must be alone in a batch; use 'go' to split batches", }, - }, doc.Errors) + }, doc.Errors()) } func TestBuggyDeclare(t *testing.T) { // this caused parses to infinitely loop; regression test... doc := ParseString("test.sql", `declare @EnumA int = 4 @EnumB tinyint = 5 @ENUM_C bigint = 435;`) - assert.Equal(t, 1, len(doc.Errors)) - assert.Equal(t, "Unexpected: @EnumB", doc.Errors[0].Message) + assert.Equal(t, 1, len(doc.Errors())) + assert.Equal(t, "Unexpected: @EnumB", doc.Errors()[0].Message) } func TestCreateType(t *testing.T) { doc := ParseString("test.sql", `create type [code].MyType as table (x int not null primary key);`) - assert.Equal(t, 1, len(doc.Creates)) - assert.Equal(t, "type", doc.Creates[0].CreateType) - assert.Equal(t, "[MyType]", doc.Creates[0].QuotedName.Value) + assert.Equal(t, 1, len(doc.Creates())) + assert.Equal(t, "type", doc.Creates()[0].CreateType) + assert.Equal(t, "[MyType]", doc.Creates()[0].QuotedName.Value) } func TestPragma(t *testing.T) { @@ -179,7 +179,7 @@ create procedure [code].ProcedureShouldAlsoHavePragmasAnnotated() func TestInfiniteLoopRegression(t *testing.T) { // success if we terminate!... 
doc := ParseString("test.sql", `@declare`) - assert.Equal(t, 1, len(doc.Errors)) + assert.Equal(t, 1, len(doc.Errors())) } func TestDeclareSeparation(t *testing.T) { @@ -190,7 +190,7 @@ func TestDeclareSeparation(t *testing.T) { doc := ParseString("test.sql", ` declare @EnumFirst int = 3, @EnumSecond varchar(max) = 'hello'declare @EnumThird int=3 declare @EnumFourth int=4;declare @EnumFifth int =5 `) - //repr.Println(doc.Declares) + //repr.Println(doc.Declares()) require.Equal(t, []Declare{ { VariableName: "@EnumFirst", @@ -217,7 +217,7 @@ declare @EnumFirst int = 3, @EnumSecond varchar(max) = 'hello'declare @EnumThird Datatype: Type{BaseType: "int"}, Literal: Unparsed{Type: NumberToken, RawValue: "5"}, }, - }, doc.WithoutPos().Declares) + }, doc.WithoutPos().Declares()) } func TestBatchDivisionsAndCreateStatements(t *testing.T) { @@ -232,7 +232,7 @@ go create type [code].Batch3 as table (x int); `) commentCount := 0 - for _, c := range doc.Creates { + for _, c := range doc.Creates() { for _, b := range c.Body { if strings.Contains(b.RawValue, "2nd") { commentCount++ @@ -251,13 +251,13 @@ create type [code].Type1 as table (x int); create type [code].Type2 as table (x int); create type [code].Type3 as table (x int); `) - require.Equal(t, 3, len(doc.Creates)) - assert.Equal(t, "[Type1]", doc.Creates[0].QuotedName.Value) - assert.Equal(t, "[Type3]", doc.Creates[2].QuotedName.Value) + require.Equal(t, 3, len(doc.Creates())) + assert.Equal(t, "[Type1]", doc.Creates()[0].QuotedName.Value) + assert.Equal(t, "[Type3]", doc.Creates()[2].QuotedName.Value) // There was a bug that the last item in the body would be the 'create' // of the next statement; regression test.. 
- assert.Equal(t, "\n", doc.Creates[0].Body[len(doc.Creates[0].Body)-1].RawValue) - assert.Equal(t, "create", doc.Creates[1].Body[0].RawValue) + assert.Equal(t, "\n", doc.Creates()[0].Body[len(doc.Creates()[0].Body)-1].RawValue) + assert.Equal(t, "create", doc.Creates()[1].Body[0].RawValue) } func TestCreateProcs(t *testing.T) { @@ -270,10 +270,10 @@ create type [code].MyType () create procedure [code].MyProcedure () `) // First function and last procedure triggers errors. - require.Equal(t, 2, len(doc.Errors)) + require.Equal(t, 2, len(doc.Errors())) emsg := "a procedure/function must be alone in a batch; use 'go' to split batches" - assert.Equal(t, emsg, doc.Errors[0].Message) - assert.Equal(t, emsg, doc.Errors[1].Message) + assert.Equal(t, emsg, doc.Errors()[0].Message) + assert.Equal(t, emsg, doc.Errors()[1].Message) } @@ -283,14 +283,14 @@ func TestCreateProcs2(t *testing.T) { create type [code].MyType () create procedure [code].FirstProc as table (x int) `) - //repr.Println(doc.Errors) + //repr.Println(doc.Errors()) // Code above was mainly to be able to step through parser in a given way. // First function triggers an error. Then create type is parsed which is // fine sharing a batch with others. 
- require.Equal(t, 1, len(doc.Errors)) + require.Equal(t, 1, len(doc.Errors())) emsg := "a procedure/function must be alone in a batch; use 'go' to split batches" - assert.Equal(t, emsg, doc.Errors[0].Message) + assert.Equal(t, emsg, doc.Errors()[0].Message) } func TestCreateProcsAndCheckForRoutineName(t *testing.T) { @@ -322,12 +322,12 @@ create procedure [code].[transform:safeguarding.Calculation/HEAD](@now datetime2 }, } for _, tc := range testcases { - require.Equal(t, 0, len(tc.doc.Errors)) - assert.Len(t, tc.doc.Creates, 1) - assert.Greater(t, len(tc.doc.Creates[0].Body), tc.expectedIndex) + require.Equal(t, 0, len(tc.doc.Errors())) + assert.Len(t, tc.doc.Creates(), 1) + assert.Greater(t, len(tc.doc.Creates()[0].Body), tc.expectedIndex) assert.Equal(t, fmt.Sprintf(templateRoutineName, tc.expectedProcName), - tc.doc.Creates[0].Body[tc.expectedIndex].RawValue, + tc.doc.Creates()[0].Body[tc.expectedIndex].RawValue, ) } } @@ -342,9 +342,9 @@ end // Code above was mainly to be able to step through parser in a given way. // First function triggers an error. Then create type is parsed which is // fine sharing a batch with others. - require.Equal(t, 2, len(doc.Errors)) - assert.Equal(t, "`go` should be alone on a line without any comments", doc.Errors[0].Message) - assert.Equal(t, "Expected 'declare' or 'create', got: end", doc.Errors[1].Message) + require.Equal(t, 2, len(doc.Errors())) + assert.Equal(t, "`go` should be alone on a line without any comments", doc.Errors()[0].Message) + assert.Equal(t, "Expected 'declare' or 'create', got: end", doc.Errors()[1].Message) } func TestCreateAnnotationHappyDay(t *testing.T) { @@ -362,8 +362,8 @@ create procedure [code].Foo as begin end `) assert.Equal(t, "-- This is part of annotation\n--! key1: a\n--! key2: b\n--! 
key3: [1,2,3]", - doc.Creates[0].DocstringAsString()) - s, err := doc.Creates[0].DocstringYamldoc() + doc.Creates()[0].DocstringAsString()) + s, err := doc.Creates()[0].DocstringYamldoc() assert.NoError(t, err) assert.Equal(t, "key1: a\nkey2: b\nkey3: [1,2,3]", @@ -372,7 +372,7 @@ create procedure [code].Foo as begin end var x struct { Key1 string `yaml:"key1"` } - require.NoError(t, doc.Creates[0].ParseYamlInDocstring(&x)) + require.NoError(t, doc.Creates()[0].ParseYamlInDocstring(&x)) assert.Equal(t, "a", x.Key1) } @@ -387,7 +387,7 @@ create procedure [code].Foo as begin end `) assert.Equal(t, "-- docstring here", - doc.Creates[0].DocstringAsString()) + doc.Creates()[0].DocstringAsString()) } func TestCreateAnnotationErrors(t *testing.T) { @@ -397,7 +397,7 @@ func TestCreateAnnotationErrors(t *testing.T) { -- This comment after yamldoc is illegal; this also prevents multiple embedded YAML documents create procedure [code].Foo as begin end `) - _, err := doc.Creates[0].DocstringYamldoc() + _, err := doc.Creates()[0].DocstringYamldoc() assert.Equal(t, "test.sql:3:1 once embedded yaml document is started (lines prefixed with `--!`), it must continue until create statement", err.Error()) @@ -407,7 +407,7 @@ create procedure [code].Foo as begin end --!key4: 1 create procedure [code].Foo as begin end `) - _, err = doc.Creates[0].DocstringYamldoc() + _, err = doc.Creates()[0].DocstringYamldoc() assert.Equal(t, "test.sql:3:1 YAML document in docstring; missing space after `--!`", err.Error()) @@ -433,8 +433,8 @@ create function [code].Func1() returns int as begin return 1 end filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) require.NoError(t, err) assert.Len(t, filenames, 2) - assert.Len(t, doc.Creates, 2) - assert.Len(t, doc.Declares, 1) + assert.Len(t, doc.Creates(), 2) + assert.Len(t, doc.Declares(), 1) }) t.Run("filters by include tags", func(t *testing.T) { @@ -455,8 +455,8 @@ create procedure [code].Excluded as begin end require.NoError(t, err) 
assert.Len(t, filenames, 1) assert.Contains(t, filenames[0], "included.sql") - assert.Len(t, doc.Creates, 1) - assert.Equal(t, "[Included]", doc.Creates[0].QuotedName.Value) + assert.Len(t, doc.Creates(), 1) + assert.Equal(t, "[Included]", doc.Creates()[0].QuotedName.Value) }) t.Run("detects duplicate files with same hash", func(t *testing.T) { @@ -488,7 +488,7 @@ create procedure [code].Excluded as begin end require.NoError(t, err) assert.Len(t, filenames, 1) assert.Contains(t, filenames[0], "sqlcode.sql") - assert.Len(t, doc.Creates, 1) + assert.Len(t, doc.Creates(), 1) }) t.Run("skips hidden directories", func(t *testing.T) { @@ -508,7 +508,7 @@ create procedure [code].Excluded as begin end require.NoError(t, err) assert.Len(t, filenames, 1) assert.Contains(t, filenames[0], "visible.sql") - assert.Len(t, doc.Creates, 1) + assert.Len(t, doc.Creates(), 1) }) t.Run("handles dependencies and topological sort", func(t *testing.T) { @@ -524,10 +524,10 @@ create procedure [code].Excluded as begin end filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) require.NoError(t, err) assert.Len(t, filenames, 2) - assert.Len(t, doc.Creates, 2) + assert.Len(t, doc.Creates(), 2) // Proc2 should come before Proc1 due to dependency - assert.Equal(t, "[Proc2]", doc.Creates[0].QuotedName.Value) - assert.Equal(t, "[Proc1]", doc.Creates[1].QuotedName.Value) + assert.Equal(t, "[Proc2]", doc.Creates()[0].QuotedName.Value) + assert.Equal(t, "[Proc1]", doc.Creates()[1].QuotedName.Value) }) t.Run("reports topological sort errors", func(t *testing.T) { @@ -541,9 +541,9 @@ create procedure [code].Excluded as begin end } _, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) - require.NoError(t, err) // filesystem error should be nil - assert.NotEmpty(t, doc.Errors) // but parsing errors should exist - assert.Contains(t, doc.Errors[0].Message, "Detected a dependency cycle") + require.NoError(t, err) // filesystem error should be nil + assert.NotEmpty(t, doc.Errors()) // but parsing errors 
should exist + assert.Contains(t, doc.Errors()[0].Message, "Detected a dependency cycle") }) t.Run("handles multiple filesystems", func(t *testing.T) { @@ -563,7 +563,7 @@ create procedure [code].Excluded as begin end assert.Len(t, filenames, 2) assert.Contains(t, filenames[0], "fs[0]:") assert.Contains(t, filenames[1], "fs[1]:") - assert.Len(t, doc.Creates, 2) + assert.Len(t, doc.Creates(), 2) }) t.Run("detects sqlcode files by pragma header", func(t *testing.T) { @@ -578,7 +578,7 @@ create procedure NotInCodeSchema.Test as begin end`), require.NoError(t, err) assert.Len(t, filenames, 1) // Should still parse even though it will have errors (not in [code] schema) - assert.NotEmpty(t, doc.Errors) + assert.NotEmpty(t, doc.Errors()) }) t.Run("handles pgsql extension", func(t *testing.T) { @@ -599,8 +599,8 @@ $$; filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) require.NoError(t, err) assert.Len(t, filenames, 1) - assert.Len(t, doc.Creates, 1) - assert.Equal(t, &stdlib.Driver{}, doc.Creates[0].Driver) + assert.Len(t, doc.Creates(), 1) + assert.Equal(t, &stdlib.Driver{}, doc.Creates()[0].Driver) }) t.Run("empty filesystem returns empty results", func(t *testing.T) { @@ -609,8 +609,8 @@ $$; filenames, doc, err := ParseFilesystems([]fs.FS{fsys}, nil) require.NoError(t, err) assert.Empty(t, filenames) - assert.Empty(t, doc.Creates) - assert.Empty(t, doc.Declares) + assert.Empty(t, doc.Creates()) + assert.Empty(t, doc.Declares()) }) } From 5e807d5e5d8ec73ae16c296590f3dd1e7c94f3ea Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 19:58:05 +0100 Subject: [PATCH 20/28] Renamed the existing Document struct to be specific for T-SQL. 
--- sqlparser/tsql_document.go | 588 ++++++++++++++++++++++++ sqlparser/tsql_document_test.go | 764 ++++++++++++++++++++++++++++++++ 2 files changed, 1352 insertions(+) create mode 100644 sqlparser/tsql_document.go create mode 100644 sqlparser/tsql_document_test.go diff --git a/sqlparser/tsql_document.go b/sqlparser/tsql_document.go new file mode 100644 index 0000000..343ef7c --- /dev/null +++ b/sqlparser/tsql_document.go @@ -0,0 +1,588 @@ +package sqlparser + +import ( + "fmt" + "sort" + "strings" + + "github.com/jackc/pgx/v5/stdlib" + mssql "github.com/microsoft/go-mssqldb" +) + +type TSqlDocument struct { + pragmaIncludeIf []string + creates []Create + declares []Declare + errors []Error +} + +func (d TSqlDocument) HasErrors() bool { + return len(d.errors) > 0 +} + +func (d TSqlDocument) Creates() []Create { + return d.creates +} + +func (d TSqlDocument) Declares() []Declare { + return d.declares +} + +func (d TSqlDocument) Errors() []Error { + return d.errors +} +func (d TSqlDocument) PragmaIncludeIf() []string { + return d.pragmaIncludeIf +} + +func (d *TSqlDocument) Sort() { + // Do the topological sort; and include any error with it as part + // of `result`, *not* return it as err + sortedCreates, errpos, sortErr := TopologicalSort(d.creates) + + if sortErr != nil { + d.errors = append(d.errors, Error{ + Pos: errpos, + Message: sortErr.Error(), + }) + } else { + d.creates = sortedCreates + } +} + +// Transform a TSqlDocument to remove all Position information; this is used +// to 'unclutter' a DOM to more easily write assertions on it. 
+func (d TSqlDocument) WithoutPos() Document { + var cs []Create + for _, x := range d.creates { + cs = append(cs, x.WithoutPos()) + } + var ds []Declare + for _, x := range d.declares { + ds = append(ds, x.WithoutPos()) + } + var es []Error + for _, x := range d.errors { + es = append(es, x.WithoutPos()) + } + return &TSqlDocument{ + creates: cs, + declares: ds, + errors: es, + } +} + +func (d *TSqlDocument) Include(other Document) { + // Do not copy pragmaIncludeIf, since that is local to a single file. + // Its contents is also present in each Create. + d.declares = append(d.declares, other.Declares()...) + d.creates = append(d.creates, other.Creates()...) + d.errors = append(d.errors, other.Errors()...) +} + +func (d *TSqlDocument) parseSinglePragma(s *Scanner) { + pragma := strings.TrimSpace(strings.TrimPrefix(s.Token(), "--sqlcode:")) + if pragma == "" { + return + } + parts := strings.Split(pragma, " ") + if len(parts) != 2 { + d.addError(s, "Illegal pragma: "+s.Token()) + return + } + if parts[0] != "include-if" { + d.addError(s, "Illegal pragma: "+s.Token()) + return + } + d.pragmaIncludeIf = append(d.pragmaIncludeIf, strings.Split(parts[1], ",")...) 
+} + +func (d *TSqlDocument) ParsePragmas(s *Scanner) { + for s.TokenType() == PragmaToken { + d.parseSinglePragma(s) + s.NextNonWhitespaceToken() + } +} + +func (d TSqlDocument) Empty() bool { + return len(d.creates) == 0 || len(d.declares) == 0 +} + +func (d *TSqlDocument) addError(s *Scanner, msg string) { + d.errors = append(d.errors, Error{ + Pos: s.Start(), + Message: msg, + }) +} + +func (d *TSqlDocument) unexpectedTokenError(s *Scanner) { + d.addError(s, "Unexpected: "+s.Token()) +} + +func (doc *TSqlDocument) parseTypeExpression(s *Scanner) (t Type) { + parseArgs := func() { + // parses *after* the initial (; consumes trailing ) + for { + switch { + case s.TokenType() == NumberToken: + t.Args = append(t.Args, s.Token()) + case s.TokenType() == UnquotedIdentifierToken && s.TokenLower() == "max": + t.Args = append(t.Args, "max") + default: + doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + return + } + s.NextNonWhitespaceCommentToken() + switch { + case s.TokenType() == CommaToken: + s.NextNonWhitespaceCommentToken() + continue + case s.TokenType() == RightParenToken: + s.NextNonWhitespaceCommentToken() + return + default: + doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + return + } + } + } + + if s.TokenType() != UnquotedIdentifierToken { + panic("assertion failed, bug in caller") + } + t.BaseType = s.Token() + s.NextNonWhitespaceCommentToken() + if s.TokenType() == LeftParenToken { + s.NextNonWhitespaceCommentToken() + parseArgs() + } + return +} + +func (doc *TSqlDocument) parseDeclare(s *Scanner) (result []Declare) { + declareStart := s.Start() + // parse what is *after* the `declare` reserved keyword +loop: + for { + if s.TokenType() != VariableIdentifierToken { + doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + return + } + + variableName := s.Token() + if !strings.HasPrefix(strings.ToLower(variableName), "@enum") && + !strings.HasPrefix(strings.ToLower(variableName), "@global") && + 
!strings.HasPrefix(strings.ToLower(variableName), "@const") { + doc.addError(s, "sqlcode constants needs to have names starting with @Enum, @Global or @Const: "+variableName) + } + s.NextNonWhitespaceCommentToken() + var variableType Type + switch s.TokenType() { + case EqualToken: + doc.addError(s, "sqlcode constants needs a type declared explicitly") + s.NextNonWhitespaceCommentToken() + case UnquotedIdentifierToken: + variableType = doc.parseTypeExpression(s) + } + + if s.TokenType() != EqualToken { + doc.addError(s, "sqlcode constants needs to be assigned at once using =") + doc.recoverToNextStatement(s) + } + + switch s.NextNonWhitespaceCommentToken() { + case NumberToken, NVarcharLiteralToken, VarcharLiteralToken: + result = append(result, Declare{ + Start: declareStart, + Stop: s.Stop(), + VariableName: variableName, + Datatype: variableType, + Literal: CreateUnparsed(s), + }) + default: + doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + return + } + + switch s.NextNonWhitespaceCommentToken() { + case CommaToken: + s.NextNonWhitespaceCommentToken() + continue + case SemicolonToken: + s.NextNonWhitespaceCommentToken() + break loop + default: + break loop + } + } + if len(result) == 0 { + doc.addError(s, "incorrect syntax; no variables successfully declared") + } + return +} + +func (doc *TSqlDocument) parseBatchSeparator(s *Scanner) { + // just saw a 'go'; just make sure there's nothing bad trailing it + // (if there is, convert to errors and move on until the line is consumed + errorEmitted := false + for { + switch s.NextToken() { + case WhitespaceToken: + continue + case MalformedBatchSeparatorToken: + if !errorEmitted { + doc.addError(s, "`go` should be alone on a line without any comments") + errorEmitted = true + } + continue + default: + return + } + } +} + +func (doc *TSqlDocument) parseDeclareBatch(s *Scanner) (hasMore bool) { + if s.ReservedWord() != "declare" { + panic("assertion failed, incorrect use in caller") + } + for { + tt := 
s.TokenType() + switch { + case tt == EOFToken: + return false + case tt == ReservedWordToken && s.ReservedWord() == "declare": + s.NextNonWhitespaceCommentToken() + d := doc.parseDeclare(s) + doc.declares = append(doc.declares, d...) + case tt == ReservedWordToken && s.ReservedWord() != "declare": + doc.addError(s, "Only 'declare' allowed in this batch") + doc.recoverToNextStatement(s) + case tt == BatchSeparatorToken: + doc.parseBatchSeparator(s) + return true + default: + doc.unexpectedTokenError(s) + doc.recoverToNextStatement(s) + } + } +} + +func (doc *TSqlDocument) ParseBatch(s *Scanner, isFirst bool) (hasMore bool) { + var nodes []Unparsed + var docstring []PosString + newLineEncounteredInDocstring := false + + var createCountInBatch int + + for { + tt := s.TokenType() + switch tt { + case EOFToken: + return false + case WhitespaceToken, MultilineCommentToken: + nodes = append(nodes, CreateUnparsed(s)) + // do not reset token for a single trailing newline + t := s.Token() + if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { + newLineEncounteredInDocstring = true + } else { + docstring = nil + } + s.NextToken() + case SinglelineCommentToken: + // We build up a list of single line comments for the "docstring"; + // it is reset whenever we encounter something else + docstring = append(docstring, PosString{s.Start(), s.Token()}) + nodes = append(nodes, CreateUnparsed(s)) + newLineEncounteredInDocstring = false + s.NextToken() + case ReservedWordToken: + switch s.ReservedWord() { + case "declare": + // First declare-statement; enter a mode where we assume all contents + // of batch are declare statements + if !isFirst { + doc.addError(s, "'declare' statement only allowed in first batch") + } + // regardless of errors, go on and parse as far as we get... + return doc.parseDeclareBatch(s) + case "create": + // should be start of create procedure or create function... 
+ c := doc.parseCreate(s, createCountInBatch) + + if strings.HasSuffix(string(s.file), ".sql") { + c.Driver = &mssql.Driver{} + } + if strings.HasSuffix(string(s.file), ".pgsql") { + c.Driver = &stdlib.Driver{} + } + + // *prepend* what we saw before getting to the 'create' + createCountInBatch++ + c.Body = append(nodes, c.Body...) + c.Docstring = docstring + doc.creates = append(doc.creates, c) + default: + doc.addError(s, "Expected 'declare' or 'create', got: "+s.ReservedWord()) + s.NextToken() + } + case BatchSeparatorToken: + doc.parseBatchSeparator(s) + return true + default: + doc.unexpectedTokenError(s) + s.NextToken() + docstring = nil + } + } +} + +func (d *TSqlDocument) recoverToNextStatementCopying(s *Scanner, target *[]Unparsed) { + // We hit an unexpected token ... as an heuristic for continuing parsing, + // skip parsing until we hit a reserved word that starts a statement + // we recognize + for { + NextTokenCopyingWhitespace(s, target) + switch s.TokenType() { + case ReservedWordToken: + switch s.ReservedWord() { + case "declare", "create", "go": + return + } + case EOFToken: + return + default: + CopyToken(s, target) + } + } +} + +func (d *TSqlDocument) recoverToNextStatement(s *Scanner) { + // We hit an unexpected token ... as an heuristic for continuing parsing, + // skip parsing until we hit a reserved word that starts a statement + // we recognize + for { + s.NextNonWhitespaceCommentToken() + switch s.TokenType() { + case ReservedWordToken: + switch s.ReservedWord() { + case "declare", "create", "go": + return + } + case EOFToken: + return + } + } +} + +// parseCodeschemaName parses `[code] . something`, and returns `something` +// in quoted form (`[something]`). Also copy to `target`. Empty string on error. +// Note: To follow conventions, consume one extra token at the end even if we know +// it fill not be consumed by this function... 
+func (d *TSqlDocument) parseCodeschemaName(s *Scanner, target *[]Unparsed) PosString { + CopyToken(s, target) + NextTokenCopyingWhitespace(s, target) + if s.TokenType() != DotToken { + d.addError(s, fmt.Sprintf("[code] must be followed by '.'")) + d.recoverToNextStatementCopying(s, target) + return PosString{Value: ""} + } + CopyToken(s, target) + + NextTokenCopyingWhitespace(s, target) + switch s.TokenType() { + case UnquotedIdentifierToken: + // To get something uniform for comparison, quote all names + CopyToken(s, target) + result := PosString{Pos: s.Start(), Value: "[" + s.Token() + "]"} + NextTokenCopyingWhitespace(s, target) + return result + case QuotedIdentifierToken: + CopyToken(s, target) + result := PosString{Pos: s.Start(), Value: s.Token()} + NextTokenCopyingWhitespace(s, target) + return result + default: + d.addError(s, fmt.Sprintf("[code]. must be followed an identifier")) + d.recoverToNextStatementCopying(s, target) + return PosString{Value: ""} + } +} + +// parseCreate parses anything that starts with "create". Position is +// *on* the create token. +// At this stage in sqlcode parser development we're only interested +// in procedures/functions/types as opaque blocks of SQL code where +// we only track dependencies between them and their declared name; +// so we treat them with the same code. We consume until the end of +// the batch; only one declaration allowed per batch. Everything +// parsed here will also be added to `batch`. On any error, copying +// to batch stops / becomes erratic.. 
+func (d *TSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (result Create) { + if s.ReservedWord() != "create" { + panic("illegal use by caller") + } + CopyToken(s, &result.Body) + + NextTokenCopyingWhitespace(s, &result.Body) + + createType := strings.ToLower(s.Token()) + if !(createType == "procedure" || createType == "function" || createType == "type") { + d.addError(s, fmt.Sprintf("sqlcode only supports creating procedures, functions or types; not `%s`", createType)) + d.recoverToNextStatementCopying(s, &result.Body) + return + } + if (createType == "procedure" || createType == "function") && createCountInBatch > 0 { + d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") + d.recoverToNextStatementCopying(s, &result.Body) + return + } + + result.CreateType = createType + CopyToken(s, &result.Body) + + NextTokenCopyingWhitespace(s, &result.Body) + + // Insist on [code]. + if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" { + d.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType)) + d.recoverToNextStatementCopying(s, &result.Body) + return + } + result.QuotedName = d.parseCodeschemaName(s, &result.Body) + if result.QuotedName.String() == "" { + return + } + + // We have matched "create [code]."; at this + // point we copy the rest until the batch ends; *but* track dependencies + // + some other details mentioned below + + //firstAs := true // See comment below on rowcount + +tailloop: + for { + tt := s.TokenType() + switch { + case tt == ReservedWordToken && s.ReservedWord() == "create": + // So, we're currently parsing 'create ...' and we see another 'create'. 
+ // We split in two cases depending on the context we are currently in + // (createType is referring to how we entered this function, *NOT* the + // `create` statement we are looking at now + switch createType { // note: this is the *outer* create type, not the one of current scanner position + case "function", "procedure": + // Within a function/procedure we can allow 'create index', 'create table' and nothing + // else. (Well, only procedures can have them, but we'll leave it to T-SQL to complain + // about that aspect, not relevant for batch / dependency parsing) + // + // What is important is a function/procedure/type isn't started on without a 'go' + // in between; so we block those 3 from appearing in the same batch + CopyToken(s, &result.Body) + NextTokenCopyingWhitespace(s, &result.Body) + tt2 := s.TokenType() + + if (tt2 == ReservedWordToken && (s.ReservedWord() == "function" || s.ReservedWord() == "procedure")) || + (tt2 == UnquotedIdentifierToken && s.TokenLower() == "type") { + d.recoverToNextStatementCopying(s, &result.Body) + d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") + return + } + case "type": + // We allow more than one type creation in a batch; and 'create' can never appear + // scoped within 'create type'. 
So at a new create we are done with the previous + // one, and return it -- the caller can then re-enter this function from the top + break tailloop + default: + panic("assertion failed") + } + + case tt == EOFToken || tt == BatchSeparatorToken: + break tailloop + case tt == QuotedIdentifierToken && s.Token() == "[code]": + // Parse a dependency + dep := d.parseCodeschemaName(s, &result.Body) + found := false + for _, existing := range result.DependsOn { + if existing.Value == dep.Value { + found = true + break + } + } + if !found { + result.DependsOn = append(result.DependsOn, dep) + } + case tt == ReservedWordToken && s.Token() == "as": + CopyToken(s, &result.Body) + NextTokenCopyingWhitespace(s, &result.Body) + /* + TODO: Fix and re-enable + This code add RoutineName for convenience. So: + + create procedure [code@5420c0269aaf].Test as + begin + select 1 + end + go + + becomes: + + create procedure [code@5420c0269aaf].Test as + declare @RoutineName nvarchar(128) + set @RoutineName = 'Test' + begin + select 1 + end + go + + However, for some very strange reason, @@rowcount is 1 with the first version, + and it is 2 with the second version. + if firstAs { + // Add the `RoutineName` token as a convenience, so that we can refer to the procedure/function name + // from inside the procedure (for example, when logging) + if result.CreateType == "procedure" { + procNameToken := Unparsed{ + Type: OtherToken, + RawValue: fmt.Sprintf(templateRoutineName, strings.Trim(result.QuotedName.Value, "[]")), + } + result.Body = append(result.Body, procNameToken) + } + firstAs = false + } + */ + + default: + CopyToken(s, &result.Body) + NextTokenCopyingWhitespace(s, &result.Body) + } + } + + sort.Slice(result.DependsOn, func(i, j int) bool { + return result.DependsOn[i].Value < result.DependsOn[j].Value + }) + return +} + +// NextTokenCopyingWhitespace is like s.NextToken(), but if whitespace is encountered +// it is simply copied into `target`. 
Upon return, the scanner is located at a non-whitespace +// token, and target is either unmodified or filled with some whitespace nodes. +func NextTokenCopyingWhitespace(s *Scanner, target *[]Unparsed) { + for { + tt := s.NextToken() + switch tt { + case EOFToken, BatchSeparatorToken: + // do not copy + return + case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: + // copy, and loop around + CopyToken(s, target) + continue + default: + return + } + } + +} diff --git a/sqlparser/tsql_document_test.go b/sqlparser/tsql_document_test.go new file mode 100644 index 0000000..b4ef303 --- /dev/null +++ b/sqlparser/tsql_document_test.go @@ -0,0 +1,764 @@ +package sqlparser + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDocument_addError(t *testing.T) { + t.Run("adds error with position", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "select") + s.NextToken() + + doc.addError(s, "test error message") + require.True(t, doc.HasErrors()) + assert.Equal(t, "test error message", doc.errors[0].Message) + assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, doc.errors[0].Pos) + }) + + t.Run("accumulates multiple errors", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "abc def") + s.NextToken() + doc.addError(s, "error 1") + s.NextToken() + doc.addError(s, "error 2") + + require.Len(t, doc.errors, 2) + assert.Equal(t, "error 1", doc.errors[0].Message) + assert.Equal(t, "error 2", doc.errors[1].Message) + }) + + t.Run("creates error with token text", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "unexpected_token") + s.NextToken() + + doc.unexpectedTokenError(s) + + require.Len(t, doc.errors, 1) + assert.Equal(t, "Unexpected: unexpected_token", doc.errors[0].Message) + }) +} + +func TestDocument_parseTypeExpression(t *testing.T) { + t.Run("parses simple type without args", func(t *testing.T) { + doc := 
&TSqlDocument{} + s := NewScanner("test.sql", "int") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "int", typ.BaseType) + assert.Empty(t, typ.Args) + }) + + t.Run("parses type with single arg", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "varchar(50)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "varchar", typ.BaseType) + assert.Equal(t, []string{"50"}, typ.Args) + }) + + t.Run("parses type with multiple args", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "decimal(10, 2)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "decimal", typ.BaseType) + assert.Equal(t, []string{"10", "2"}, typ.Args) + }) + + t.Run("parses type with max", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "nvarchar(max)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "nvarchar", typ.BaseType) + assert.Equal(t, []string{"max"}, typ.Args) + }) + + t.Run("handles invalid arg", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "varchar(invalid)") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "varchar", typ.BaseType) + assert.NotEmpty(t, doc.errors) + }) + + t.Run("panics if not on identifier", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "123") + s.NextToken() + + assert.Panics(t, func() { + doc.parseTypeExpression(s) + }) + }) +} + +func TestDocument_parseDeclare(t *testing.T) { + t.Run("parses single enum declaration", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@EnumStatus int = 42") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "@EnumStatus", declares[0].VariableName) + assert.Equal(t, "int", declares[0].Datatype.BaseType) + assert.Equal(t, "42", declares[0].Literal.RawValue) + }) + + t.Run("parses multiple declarations 
with comma", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@EnumA int = 1, @EnumB int = 2;") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 2) + assert.Equal(t, "@EnumA", declares[0].VariableName) + assert.Equal(t, "@EnumB", declares[1].VariableName) + }) + + t.Run("parses string literal", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@EnumName nvarchar(50) = N'test'") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "N'test'", declares[0].Literal.RawValue) + }) + + t.Run("errors on invalid variable name", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@InvalidName int = 1") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "@InvalidName") + }) + + t.Run("errors on missing type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@EnumTest = 42") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "type declared explicitly") + }) + + t.Run("errors on missing assignment", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@EnumTest int") + s.NextToken() + + doc.parseDeclare(s) + + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "needs to be assigned") + }) + + t.Run("accepts @Global prefix", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@GlobalSetting int = 100") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "@GlobalSetting", declares[0].VariableName) + assert.Empty(t, doc.errors) + }) + + t.Run("accepts @Const prefix", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "@ConstValue int 
= 200") + s.NextToken() + + declares := doc.parseDeclare(s) + + require.Len(t, declares, 1) + assert.Equal(t, "@ConstValue", declares[0].VariableName) + assert.Empty(t, doc.errors) + }) +} + +func TestDocument_parseBatchSeparator(t *testing.T) { + t.Run("parses valid go separator", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "go\n") + s.NextToken() + + doc.parseBatchSeparator(s) + + assert.Empty(t, doc.errors) + }) + + t.Run("errors on malformed separator", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "go -- comment") + s.NextToken() + + doc.parseBatchSeparator(s) + + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "should be alone") + }) +} + +func TestDocument_parseCodeschemaName(t *testing.T) { + t.Run("parses unquoted identifier", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "[code].TestProc") + s.NextToken() + var target []Unparsed + + result := doc.parseCodeschemaName(s, &target) + + assert.Equal(t, "[TestProc]", result.Value) + assert.NotEmpty(t, target) + }) + + t.Run("parses quoted identifier", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "[code].[Test Proc]") + s.NextToken() + var target []Unparsed + + result := doc.parseCodeschemaName(s, &target) + + assert.Equal(t, "[Test Proc]", result.Value) + }) + + t.Run("errors on missing dot", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "[code] TestProc") + s.NextToken() + var target []Unparsed + + result := doc.parseCodeschemaName(s, &target) + + assert.Equal(t, "", result.Value) + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "must be followed by '.'") + }) + + t.Run("errors on missing identifier", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "[code].123") + s.NextToken() + var target []Unparsed + + result := doc.parseCodeschemaName(s, &target) + + assert.Equal(t, "", 
result.Value) + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "must be followed an identifier") + }) +} + +func TestDocument_parseCreate(t *testing.T) { + t.Run("parses simple procedure", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure [code].TestProc as begin end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "procedure", create.CreateType) + assert.Equal(t, "[TestProc]", create.QuotedName.Value) + assert.NotEmpty(t, create.Body) + }) + + t.Run("parses function", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create function [code].TestFunc() returns int as begin return 1 end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + assert.Equal(t, "[TestFunc]", create.QuotedName.Value) + }) + + t.Run("parses type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create type [code].TestType as table (id int)") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "type", create.CreateType) + assert.Equal(t, "[TestType]", create.QuotedName.Value) + }) + + t.Run("tracks dependencies", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure [code].Proc1 as begin select * from [code].Table1 join [code].Table2 end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + require.Len(t, create.DependsOn, 2) + assert.Equal(t, "[Table1]", create.DependsOn[0].Value) + assert.Equal(t, "[Table2]", create.DependsOn[1].Value) + }) + + t.Run("deduplicates dependencies", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure [code].Proc1 as begin select * from [code].Table1; select * from [code].Table1 end") + s.NextToken() + 
s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + require.Len(t, create.DependsOn, 1) + assert.Equal(t, "[Table1]", create.DependsOn[0].Value) + }) + + t.Run("errors on unsupported create type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create table [code].TestTable (id int)") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + doc.parseCreate(s, 0) + + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "only supports creating procedures") + }) + + t.Run("errors on multiple procedures in batch", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure [code].Proc2 as begin end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + doc.parseCreate(s, 1) + + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "must be alone in a batch") + }) + + t.Run("errors on missing code schema", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure dbo.TestProc as begin end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + doc.parseCreate(s, 0) + + assert.NotEmpty(t, doc.errors) + assert.Contains(t, doc.errors[0].Message, "must be followed by [code]") + }) + + t.Run("allows create index inside procedure", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure [code].Proc as begin create index IX_Test on #temp(id) end") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "procedure", create.CreateType) + assert.Empty(t, doc.errors) + }) + + t.Run("stops at batch separator", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "create procedure [code].Proc as begin end\ngo") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "[Proc]", create.QuotedName.Value) + assert.Equal(t, BatchSeparatorToken, s.TokenType()) + 
}) + + t.Run("panics if not on create token", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "procedure") + s.NextToken() + + assert.Panics(t, func() { + doc.parseCreate(s, 0) + }) + }) +} + +func TestNextTokenCopyingWhitespace(t *testing.T) { + t.Run("copies whitespace tokens", func(t *testing.T) { + s := NewScanner("test.sql", " \n\t token") + var target []Unparsed + + NextTokenCopyingWhitespace(s, &target) + + assert.NotEmpty(t, target) + assert.Equal(t, UnquotedIdentifierToken, s.TokenType()) + }) + + t.Run("copies comments", func(t *testing.T) { + s := NewScanner("test.sql", "/* comment */ -- line\ntoken") + var target []Unparsed + + NextTokenCopyingWhitespace(s, &target) + + assert.True(t, len(target) >= 2) + assert.Equal(t, UnquotedIdentifierToken, s.TokenType()) + }) + + t.Run("stops at EOF", func(t *testing.T) { + s := NewScanner("test.sql", " ") + var target []Unparsed + + NextTokenCopyingWhitespace(s, &target) + + assert.Equal(t, EOFToken, s.TokenType()) + }) +} + +func TestCreateUnparsed(t *testing.T) { + t.Run("creates unparsed from scanner", func(t *testing.T) { + s := NewScanner("test.sql", "select") + s.NextToken() + + unparsed := CreateUnparsed(s) + + assert.Equal(t, ReservedWordToken, unparsed.Type) + assert.Equal(t, "select", unparsed.RawValue) + assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, unparsed.Start) + }) +} + +func TestDocument_recoverToNextStatement(t *testing.T) { + t.Run("recovers to declare", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "invalid tokens here declare @x int = 1") + s.NextToken() + + doc.recoverToNextStatement(s) + + assert.Equal(t, ReservedWordToken, s.TokenType()) + assert.Equal(t, "declare", s.ReservedWord()) + }) + + t.Run("recovers to create", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "bad stuff create procedure") + s.NextToken() + + doc.recoverToNextStatement(s) + + assert.Equal(t, "create", s.ReservedWord()) + 
}) + + t.Run("recovers to go", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "error error go") + s.NextToken() + + doc.recoverToNextStatement(s) + + assert.Equal(t, "go", s.ReservedWord()) + }) + + t.Run("stops at EOF", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "no keywords") + s.NextToken() + + doc.recoverToNextStatement(s) + + assert.Equal(t, EOFToken, s.TokenType()) + }) +} + +func TestDocument_recoverToNextStatementCopying(t *testing.T) { + t.Run("copies tokens while recovering", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "bad token declare") + s.NextToken() + var target []Unparsed + + doc.recoverToNextStatementCopying(s, &target) + + assert.NotEmpty(t, target) + assert.Equal(t, "declare", s.ReservedWord()) + }) +} + +func TestDocument_PostgreSQL17_parseCreate(t *testing.T) { + t.Run("parses PostgreSQL function with dollar quoting", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test_func() returns int as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + assert.Equal(t, "test_func", create.QuotedName.Value) + }) + + t.Run("parses PostgreSQL procedure", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create procedure insert_data(a integer, b integer) language sql as $$ insert into tbl values (a, b); $$") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "procedure", create.CreateType) + assert.Equal(t, "insert_data", create.QuotedName.Value) + }) + + t.Run("parses CREATE OR REPLACE", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create or replace function test_func() returns int as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + 
s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) + + t.Run("parses schema-qualified name", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function public.test_func() returns int as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Contains(t, create.QuotedName.Value, "test_func") + }) + + t.Run("parses RETURNS TABLE", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function get_users() returns table(id int, name text) as $$ select id, name from users; $$ language sql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) + + t.Run("tracks dependencies with schema prefix", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test() returns int as $$ select * from public.table1 join public.table2 on table1.id = table2.id; $$ language sql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + require.Len(t, create.DependsOn, 2) + }) + + t.Run("parses volatility categories", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test_func() returns int immutable as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) + + t.Run("parses PARALLEL SAFE", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test_func() returns int parallel safe as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) +} + +// 
func TestDocument_PostgreSQL17_Types(t *testing.T) { +// t.Run("parses composite type", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "create type address_type as (street text, city text, zip varchar(10))") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// assert.Equal(t, "type", create.CreateType) +// }) + +// t.Run("parses enum type", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "create type mood as enum ('sad', 'ok', 'happy')") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// assert.Equal(t, "type", create.CreateType) +// }) + +// t.Run("parses range type", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "create type float_range as range (subtype = float8, subtype_diff = float8mi)") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// assert.Equal(t, "type", create.CreateType) +// }) +// } + +// func TestDocument_PostgreSQL17_Extensions(t *testing.T) { +// t.Run("parses JSON functions PostgreSQL 17", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "create function test() returns jsonb as $$ select json_serialize(data) from table1; $$ language sql") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// assert.Equal(t, "function", create.CreateType) +// }) + +// t.Run("parses MERGE statement (PostgreSQL 15+)", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "create function do_merge() returns void as $$ merge into target using source on target.id = source.id when matched then update set value = source.value; $$ language sql") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// assert.Equal(t, "function", create.CreateType) +// }) +// } + +// func 
TestDocument_PostgreSQL17_Identifiers(t *testing.T) { +// t.Run("parses double-quoted identifiers", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", `create function "Test Func"() returns int as $$ begin return 1; end; $$ language plpgsql`) +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// assert.Contains(t, create.QuotedName.Value, "Test Func") +// }) + +// t.Run("parses case-sensitive identifiers", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", `create function "TestFunc"() returns int as $$ begin return 1; end; $$ language plpgsql`) +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create := doc.parseCreate(s, 0) + +// assert.Contains(t, create.QuotedName.Value, "TestFunc") +// }) +// } + +// func TestDocument_PostgreSQL17_Datatypes(t *testing.T) { +// t.Run("parses array types", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "integer[]") +// s.NextToken() + +// typ := doc.parseTypeExpression(s) + +// assert.Equal(t, "integer[]", typ.BaseType) +// }) + +// t.Run("parses serial types", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "serial") +// s.NextToken() + +// typ := doc.parseTypeExpression(s) + +// assert.Equal(t, "serial", typ.BaseType) +// }) + +// t.Run("parses text type", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "text") +// s.NextToken() + +// typ := doc.parseTypeExpression(s) + +// assert.Equal(t, "text", typ.BaseType) +// }) + +// t.Run("parses jsonb type", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "jsonb") +// s.NextToken() + +// typ := doc.parseTypeExpression(s) + +// assert.Equal(t, "jsonb", typ.BaseType) +// }) + +// t.Run("parses uuid type", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "uuid") +// s.NextToken() + +// typ := 
doc.parseTypeExpression(s) + +// assert.Equal(t, "uuid", typ.BaseType) +// }) +// } + +// func TestDocument_PostgreSQL17_BatchSeparator(t *testing.T) { +// t.Run("PostgreSQL uses semicolon not GO", func(t *testing.T) { +// doc := &TSqlDocument{} +// s := NewScanner("test.pgsql", "create function test1() returns int as $$ begin return 1; end; $$ language plpgsql; create function test2() returns int as $$ begin return 2; end; $$ language plpgsql;") +// s.NextToken() +// s.NextNonWhitespaceCommentToken() + +// create1 := doc.parseCreate(s, 0) +// assert.Equal(t, "test1", create1.QuotedName.Value) + +// // Move to next statement +// s.NextNonWhitespaceCommentToken() +// s.NextNonWhitespaceCommentToken() + +// create2 := doc.parseCreate(s, 1) +// assert.Equal(t, "test2", create2.QuotedName.Value) +// }) +// } From 494dca9c945fa54fafc00406b57fc72d553c00bb Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 19:58:27 +0100 Subject: [PATCH 21/28] Created initial PGSqlDocument for PostgreSQL. 
--- sqlparser/pgsql_document.go | 50 +++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 sqlparser/pgsql_document.go diff --git a/sqlparser/pgsql_document.go b/sqlparser/pgsql_document.go new file mode 100644 index 0000000..9aacf98 --- /dev/null +++ b/sqlparser/pgsql_document.go @@ -0,0 +1,50 @@ +package sqlparser + +type PGSqlDocument struct { + pragmaIncludeIf []string + creates []Create + errors []Error +} + +func (d PGSqlDocument) HasErrors() bool { + return len(d.errors) > 0 +} + +func (d PGSqlDocument) Creates() []Create { + return d.creates +} + +func (d PGSqlDocument) Declares() []Declare { + return nil +} + +func (d PGSqlDocument) Errors() []Error { + return d.errors +} +func (d PGSqlDocument) PragmaIncludeIf() []string { + return d.pragmaIncludeIf +} + +func (d PGSqlDocument) Empty() bool { + return len(d.creates) == 0 +} + +func (d PGSqlDocument) Sort() { + +} + +func (d PGSqlDocument) Include(other Document) { + +} + +func (d PGSqlDocument) ParsePragmas(s *Scanner) { + +} + +func (d PGSqlDocument) WithoutPos() Document { + return &PGSqlDocument{} +} + +func (d PGSqlDocument) ParseBatch(s *Scanner, isFirst bool) bool { + return false +} From 1f7b6b715e78eef8a2e256544b6f91938ef03d93 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 20:13:29 +0100 Subject: [PATCH 22/28] Updated unit test. 
--- sqlparser/tsql_document.go | 6 +- sqlparser/tsql_document_test.go | 320 +++++++++----------------------- 2 files changed, 90 insertions(+), 236 deletions(-) diff --git a/sqlparser/tsql_document.go b/sqlparser/tsql_document.go index 343ef7c..23d43a4 100644 --- a/sqlparser/tsql_document.go +++ b/sqlparser/tsql_document.go @@ -178,6 +178,7 @@ loop: !strings.HasPrefix(strings.ToLower(variableName), "@const") { doc.addError(s, "sqlcode constants needs to have names starting with @Enum, @Global or @Const: "+variableName) } + s.NextNonWhitespaceCommentToken() var variableType Type switch s.TokenType() { @@ -195,13 +196,14 @@ loop: switch s.NextNonWhitespaceCommentToken() { case NumberToken, NVarcharLiteralToken, VarcharLiteralToken: - result = append(result, Declare{ + declare := Declare{ Start: declareStart, Stop: s.Stop(), VariableName: variableName, Datatype: variableType, Literal: CreateUnparsed(s), - }) + } + result = append(result, declare) default: doc.unexpectedTokenError(s) doc.recoverToNextStatement(s) diff --git a/sqlparser/tsql_document_test.go b/sqlparser/tsql_document_test.go index b4ef303..0d5256c 100644 --- a/sqlparser/tsql_document_test.go +++ b/sqlparser/tsql_document_test.go @@ -7,106 +7,108 @@ import ( "github.com/stretchr/testify/require" ) -func TestDocument_addError(t *testing.T) { - t.Run("adds error with position", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "select") - s.NextToken() - - doc.addError(s, "test error message") - require.True(t, doc.HasErrors()) - assert.Equal(t, "test error message", doc.errors[0].Message) - assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, doc.errors[0].Pos) - }) - - t.Run("accumulates multiple errors", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "abc def") - s.NextToken() - doc.addError(s, "error 1") - s.NextToken() - doc.addError(s, "error 2") +func TestTSqlDocument(t *testing.T) { + t.Run("addError", func(t *testing.T) { + t.Run("adds 
error with position", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "select") + s.NextToken() + + doc.addError(s, "test error message") + require.True(t, doc.HasErrors()) + assert.Equal(t, "test error message", doc.errors[0].Message) + assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, doc.errors[0].Pos) + }) - require.Len(t, doc.errors, 2) - assert.Equal(t, "error 1", doc.errors[0].Message) - assert.Equal(t, "error 2", doc.errors[1].Message) - }) + t.Run("accumulates multiple errors", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "abc def") + s.NextToken() + doc.addError(s, "error 1") + s.NextToken() + doc.addError(s, "error 2") + + require.Len(t, doc.errors, 2) + assert.Equal(t, "error 1", doc.errors[0].Message) + assert.Equal(t, "error 2", doc.errors[1].Message) + }) - t.Run("creates error with token text", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "unexpected_token") - s.NextToken() + t.Run("creates error with token text", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "unexpected_token") + s.NextToken() - doc.unexpectedTokenError(s) + doc.unexpectedTokenError(s) - require.Len(t, doc.errors, 1) - assert.Equal(t, "Unexpected: unexpected_token", doc.errors[0].Message) + require.Len(t, doc.errors, 1) + assert.Equal(t, "Unexpected: unexpected_token", doc.errors[0].Message) + }) }) -} -func TestDocument_parseTypeExpression(t *testing.T) { - t.Run("parses simple type without args", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "int") - s.NextToken() + t.Run("parseTypeExpression", func(t *testing.T) { + t.Run("parses simple type without args", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "int") + s.NextToken() - typ := doc.parseTypeExpression(s) + typ := doc.parseTypeExpression(s) - assert.Equal(t, "int", typ.BaseType) - assert.Empty(t, typ.Args) - }) + assert.Equal(t, "int", 
typ.BaseType) + assert.Empty(t, typ.Args) + }) - t.Run("parses type with single arg", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "varchar(50)") - s.NextToken() + t.Run("parses type with single arg", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "varchar(50)") + s.NextToken() - typ := doc.parseTypeExpression(s) + typ := doc.parseTypeExpression(s) - assert.Equal(t, "varchar", typ.BaseType) - assert.Equal(t, []string{"50"}, typ.Args) - }) + assert.Equal(t, "varchar", typ.BaseType) + assert.Equal(t, []string{"50"}, typ.Args) + }) - t.Run("parses type with multiple args", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "decimal(10, 2)") - s.NextToken() + t.Run("parses type with multiple args", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "decimal(10, 2)") + s.NextToken() - typ := doc.parseTypeExpression(s) + typ := doc.parseTypeExpression(s) - assert.Equal(t, "decimal", typ.BaseType) - assert.Equal(t, []string{"10", "2"}, typ.Args) - }) + assert.Equal(t, "decimal", typ.BaseType) + assert.Equal(t, []string{"10", "2"}, typ.Args) + }) - t.Run("parses type with max", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "nvarchar(max)") - s.NextToken() + t.Run("parses type with max", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "nvarchar(max)") + s.NextToken() - typ := doc.parseTypeExpression(s) + typ := doc.parseTypeExpression(s) - assert.Equal(t, "nvarchar", typ.BaseType) - assert.Equal(t, []string{"max"}, typ.Args) - }) + assert.Equal(t, "nvarchar", typ.BaseType) + assert.Equal(t, []string{"max"}, typ.Args) + }) - t.Run("handles invalid arg", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "varchar(invalid)") - s.NextToken() + t.Run("handles invalid arg", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "varchar(invalid)") + s.NextToken() - typ 
:= doc.parseTypeExpression(s) + typ := doc.parseTypeExpression(s) - assert.Equal(t, "varchar", typ.BaseType) - assert.NotEmpty(t, doc.errors) - }) + assert.Equal(t, "varchar", typ.BaseType) + assert.NotEmpty(t, doc.errors) + }) - t.Run("panics if not on identifier", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "123") - s.NextToken() + t.Run("panics if not on identifier", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.sql", "123") + s.NextToken() - assert.Panics(t, func() { - doc.parseTypeExpression(s) + assert.Panics(t, func() { + doc.parseTypeExpression(s) + }) }) }) } @@ -155,6 +157,9 @@ func TestDocument_parseDeclare(t *testing.T) { declares := doc.parseDeclare(s) + // in this case when we detect the missing prefix, + // we add an error and continue parsing the declaration. + // this results with it being added require.Len(t, declares, 1) assert.NotEmpty(t, doc.errors) assert.Contains(t, doc.errors[0].Message, "@InvalidName") @@ -167,7 +172,7 @@ func TestDocument_parseDeclare(t *testing.T) { declares := doc.parseDeclare(s) - require.Len(t, declares, 1) + require.Len(t, declares, 0) assert.NotEmpty(t, doc.errors) assert.Contains(t, doc.errors[0].Message, "type declared explicitly") }) @@ -177,8 +182,9 @@ func TestDocument_parseDeclare(t *testing.T) { s := NewScanner("test.sql", "@EnumTest int") s.NextToken() - doc.parseDeclare(s) + declares := doc.parseDeclare(s) + require.Len(t, declares, 0) assert.NotEmpty(t, doc.errors) assert.Contains(t, doc.errors[0].Message, "needs to be assigned") }) @@ -608,157 +614,3 @@ func TestDocument_PostgreSQL17_parseCreate(t *testing.T) { assert.Equal(t, "function", create.CreateType) }) } - -// func TestDocument_PostgreSQL17_Types(t *testing.T) { -// t.Run("parses composite type", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "create type address_type as (street text, city text, zip varchar(10))") -// s.NextToken() -// 
s.NextNonWhitespaceCommentToken() - -// create := doc.parseCreate(s, 0) - -// assert.Equal(t, "type", create.CreateType) -// }) - -// t.Run("parses enum type", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "create type mood as enum ('sad', 'ok', 'happy')") -// s.NextToken() -// s.NextNonWhitespaceCommentToken() - -// create := doc.parseCreate(s, 0) - -// assert.Equal(t, "type", create.CreateType) -// }) - -// t.Run("parses range type", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "create type float_range as range (subtype = float8, subtype_diff = float8mi)") -// s.NextToken() -// s.NextNonWhitespaceCommentToken() - -// create := doc.parseCreate(s, 0) - -// assert.Equal(t, "type", create.CreateType) -// }) -// } - -// func TestDocument_PostgreSQL17_Extensions(t *testing.T) { -// t.Run("parses JSON functions PostgreSQL 17", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "create function test() returns jsonb as $$ select json_serialize(data) from table1; $$ language sql") -// s.NextToken() -// s.NextNonWhitespaceCommentToken() - -// create := doc.parseCreate(s, 0) - -// assert.Equal(t, "function", create.CreateType) -// }) - -// t.Run("parses MERGE statement (PostgreSQL 15+)", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "create function do_merge() returns void as $$ merge into target using source on target.id = source.id when matched then update set value = source.value; $$ language sql") -// s.NextToken() -// s.NextNonWhitespaceCommentToken() - -// create := doc.parseCreate(s, 0) - -// assert.Equal(t, "function", create.CreateType) -// }) -// } - -// func TestDocument_PostgreSQL17_Identifiers(t *testing.T) { -// t.Run("parses double-quoted identifiers", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", `create function "Test Func"() returns int as $$ begin return 1; end; $$ language plpgsql`) 
-// s.NextToken() -// s.NextNonWhitespaceCommentToken() - -// create := doc.parseCreate(s, 0) - -// assert.Contains(t, create.QuotedName.Value, "Test Func") -// }) - -// t.Run("parses case-sensitive identifiers", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", `create function "TestFunc"() returns int as $$ begin return 1; end; $$ language plpgsql`) -// s.NextToken() -// s.NextNonWhitespaceCommentToken() - -// create := doc.parseCreate(s, 0) - -// assert.Contains(t, create.QuotedName.Value, "TestFunc") -// }) -// } - -// func TestDocument_PostgreSQL17_Datatypes(t *testing.T) { -// t.Run("parses array types", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "integer[]") -// s.NextToken() - -// typ := doc.parseTypeExpression(s) - -// assert.Equal(t, "integer[]", typ.BaseType) -// }) - -// t.Run("parses serial types", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "serial") -// s.NextToken() - -// typ := doc.parseTypeExpression(s) - -// assert.Equal(t, "serial", typ.BaseType) -// }) - -// t.Run("parses text type", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "text") -// s.NextToken() - -// typ := doc.parseTypeExpression(s) - -// assert.Equal(t, "text", typ.BaseType) -// }) - -// t.Run("parses jsonb type", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "jsonb") -// s.NextToken() - -// typ := doc.parseTypeExpression(s) - -// assert.Equal(t, "jsonb", typ.BaseType) -// }) - -// t.Run("parses uuid type", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", "uuid") -// s.NextToken() - -// typ := doc.parseTypeExpression(s) - -// assert.Equal(t, "uuid", typ.BaseType) -// }) -// } - -// func TestDocument_PostgreSQL17_BatchSeparator(t *testing.T) { -// t.Run("PostgreSQL uses semicolon not GO", func(t *testing.T) { -// doc := &TSqlDocument{} -// s := NewScanner("test.pgsql", 
"create function test1() returns int as $$ begin return 1; end; $$ language plpgsql; create function test2() returns int as $$ begin return 2; end; $$ language plpgsql;") -// s.NextToken() -// s.NextNonWhitespaceCommentToken() - -// create1 := doc.parseCreate(s, 0) -// assert.Equal(t, "test1", create1.QuotedName.Value) - -// // Move to next statement -// s.NextNonWhitespaceCommentToken() -// s.NextNonWhitespaceCommentToken() - -// create2 := doc.parseCreate(s, 1) -// assert.Equal(t, "test2", create2.QuotedName.Value) -// }) -// } From af756289cbe612efd2c45babaa8f55e2610eb671 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 22:34:10 +0100 Subject: [PATCH 23/28] Updated tests. --- preprocess_test.go | 39 +---- sqlparser/pgsql_document_test.go | 254 +++++++++++++++++++++++++++++++ sqlparser/tsql_document.go | 9 +- sqlparser/tsql_document_test.go | 92 ----------- 4 files changed, 257 insertions(+), 137 deletions(-) create mode 100644 sqlparser/pgsql_document_test.go diff --git a/preprocess_test.go b/preprocess_test.go index e68d8bf..85132e6 100644 --- a/preprocess_test.go +++ b/preprocess_test.go @@ -58,10 +58,7 @@ func TestLineNumberInInput(t *testing.T) { func TestSchemaSuffixFromHash(t *testing.T) { t.Run("returns a unique hash", func(t *testing.T) { - doc := sqlparser.Document{ - Declares: []sqlparser.Declare{}, - } - + doc := sqlparser.NewDocumentFromExtension(".sql") value := SchemaSuffixFromHash(doc) require.Equal(t, value, SchemaSuffixFromHash(doc)) }) @@ -99,7 +96,7 @@ create procedure [code].Test2 as begin end }) t.Run("empty document has hash", func(t *testing.T) { - doc := sqlparser.Document{} + doc := sqlparser.NewDocumentFromExtension(".pgsql") suffix := SchemaSuffixFromHash(doc) assert.Len(t, suffix, 12) }) @@ -193,7 +190,6 @@ begin select 1 end `) - doc.Creates[0].Driver = &mssql.Driver{} result, err := Preprocess(doc, "abc123", &mssql.Driver{}) require.NoError(t, err) @@ -211,8 +207,6 @@ begin end; $$ language plpgsql; `) - 
doc.Creates[0].Driver = &stdlib.Driver{} - result, err := Preprocess(doc, "abc123", &stdlib.Driver{}) require.NoError(t, err) require.Len(t, result.Batches, 1) @@ -230,8 +224,6 @@ begin select @EnumStatus end `) - doc.Creates[0].Driver = &mssql.Driver{} - result, err := Preprocess(doc, "abc123", &mssql.Driver{}) require.NoError(t, err) require.Len(t, result.Batches, 1) @@ -252,8 +244,6 @@ begin select @EnumMulti end `) - doc.Creates[0].Driver = &mssql.Driver{} - result, err := Preprocess(doc, "abc123", &mssql.Driver{}) require.NoError(t, err) require.Len(t, result.Batches, 1) @@ -272,8 +262,6 @@ begin select @EnumUndeclared end `) - doc.Creates[0].Driver = &mssql.Driver{} - _, err := Preprocess(doc, "abc123", &mssql.Driver{}) require.Error(t, err) @@ -287,34 +275,17 @@ end doc := sqlparser.ParseString("test.sql", ` create procedure [code].Test as begin end `) - doc.Creates[0].Driver = &mssql.Driver{} - _, err := Preprocess(doc, "abc]123", &mssql.Driver{}) require.Error(t, err) assert.Contains(t, err.Error(), "schemasuffix cannot contain") }) - t.Run("skips creates with empty body", func(t *testing.T) { - doc := sqlparser.Document{ - Creates: []sqlparser.Create{ - {Body: []sqlparser.Unparsed{}}, - }, - } - - result, err := Preprocess(doc, "abc123", &mssql.Driver{}) - require.NoError(t, err) - assert.Empty(t, result.Batches) - }) - t.Run("handles multiple creates", func(t *testing.T) { doc := sqlparser.ParseString("test.sql", ` create procedure [code].Proc1 as begin select 1 end go create procedure [code].Proc2 as begin select 2 end `) - doc.Creates[0].Driver = &mssql.Driver{} - doc.Creates[1].Driver = &mssql.Driver{} - result, err := Preprocess(doc, "abc123", &mssql.Driver{}) require.NoError(t, err) assert.Len(t, result.Batches, 2) @@ -332,8 +303,6 @@ begin select @EnumA, @EnumB end `) - doc.Creates[0].Driver = &mssql.Driver{} - result, err := Preprocess(doc, "abc123", &mssql.Driver{}) require.NoError(t, err) require.Len(t, result.Batches, 1) @@ -354,8 +323,6 @@ 
begin select 1 end `) - doc.Creates[0].Driver = &mssql.Driver{} - result, err := Preprocess(doc, "abc123", &mssql.Driver{}) require.NoError(t, err) require.Len(t, result.Batches, 1) @@ -375,8 +342,6 @@ begin select @ConstValue, @GlobalSetting end `) - doc.Creates[0].Driver = &mssql.Driver{} - result, err := Preprocess(doc, "abc123", &mssql.Driver{}) require.NoError(t, err) require.Len(t, result.Batches, 1) diff --git a/sqlparser/pgsql_document_test.go b/sqlparser/pgsql_document_test.go new file mode 100644 index 0000000..2d3c8af --- /dev/null +++ b/sqlparser/pgsql_document_test.go @@ -0,0 +1,254 @@ +package sqlparser + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDocument_PostgreSQL17_parseCreate(t *testing.T) { + t.Run("parses PostgreSQL function with dollar quoting", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test_func() returns int as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + assert.Equal(t, "test_func", create.QuotedName.Value) + }) + + t.Run("parses PostgreSQL procedure", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create procedure insert_data(a integer, b integer) language sql as $$ insert into tbl values (a, b); $$") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "procedure", create.CreateType) + assert.Equal(t, "insert_data", create.QuotedName.Value) + }) + + t.Run("parses CREATE OR REPLACE", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create or replace function test_func() returns int as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", 
create.CreateType) + }) + + t.Run("parses schema-qualified name", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function public.test_func() returns int as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Contains(t, create.QuotedName.Value, "test_func") + }) + + t.Run("parses RETURNS TABLE", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function get_users() returns table(id int, name text) as $$ select id, name from users; $$ language sql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) + + t.Run("tracks dependencies with schema prefix", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test() returns int as $$ select * from public.table1 join public.table2 on table1.id = table2.id; $$ language sql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + require.Len(t, create.DependsOn, 2) + }) + + t.Run("parses volatility categories", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test_func() returns int immutable as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) + + t.Run("parses PARALLEL SAFE", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test_func() returns int parallel safe as $$ begin return 1; end; $$ language plpgsql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) +} + +func TestDocument_PostgreSQL17_Types(t *testing.T) { + t.Run("parses composite type", func(t *testing.T) { + 
doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create type address_type as (street text, city text, zip varchar(10))") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "type", create.CreateType) + }) + + t.Run("parses enum type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create type mood as enum ('sad', 'ok', 'happy')") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "type", create.CreateType) + }) + + t.Run("parses range type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create type float_range as range (subtype = float8, subtype_diff = float8mi)") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "type", create.CreateType) + }) +} + +func TestDocument_PostgreSQL17_Extensions(t *testing.T) { + t.Run("parses JSON functions PostgreSQL 17", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test() returns jsonb as $$ select json_serialize(data) from table1; $$ language sql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) + + t.Run("parses MERGE statement (PostgreSQL 15+)", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function do_merge() returns void as $$ merge into target using source on target.id = source.id when matched then update set value = source.value; $$ language sql") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Equal(t, "function", create.CreateType) + }) +} + +func TestDocument_PostgreSQL17_Identifiers(t *testing.T) { + t.Run("parses double-quoted identifiers", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", `create function "Test Func"() 
returns int as $$ begin return 1; end; $$ language plpgsql`) + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Contains(t, create.QuotedName.Value, "Test Func") + }) + + t.Run("parses case-sensitive identifiers", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", `create function "TestFunc"() returns int as $$ begin return 1; end; $$ language plpgsql`) + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create := doc.parseCreate(s, 0) + + assert.Contains(t, create.QuotedName.Value, "TestFunc") + }) +} + +func TestDocument_PostgreSQL17_Datatypes(t *testing.T) { + t.Run("parses array types", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "integer[]") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "integer[]", typ.BaseType) + }) + + t.Run("parses serial types", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "serial") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "serial", typ.BaseType) + }) + + t.Run("parses text type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "text") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "text", typ.BaseType) + }) + + t.Run("parses jsonb type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "jsonb") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "jsonb", typ.BaseType) + }) + + t.Run("parses uuid type", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "uuid") + s.NextToken() + + typ := doc.parseTypeExpression(s) + + assert.Equal(t, "uuid", typ.BaseType) + }) +} + +func TestDocument_PostgreSQL17_BatchSeparator(t *testing.T) { + t.Run("PostgreSQL uses semicolon not GO", func(t *testing.T) { + doc := &TSqlDocument{} + s := NewScanner("test.pgsql", "create function test1() returns int as $$ begin return 
1; end; $$ language plpgsql; create function test2() returns int as $$ begin return 2; end; $$ language plpgsql;") + s.NextToken() + s.NextNonWhitespaceCommentToken() + + create1 := doc.parseCreate(s, 0) + assert.Equal(t, "test1", create1.QuotedName.Value) + + // Move to next statement + s.NextNonWhitespaceCommentToken() + s.NextNonWhitespaceCommentToken() + + create2 := doc.parseCreate(s, 1) + assert.Equal(t, "test2", create2.QuotedName.Value) + }) +} diff --git a/sqlparser/tsql_document.go b/sqlparser/tsql_document.go index 23d43a4..2aa35a9 100644 --- a/sqlparser/tsql_document.go +++ b/sqlparser/tsql_document.go @@ -5,7 +5,6 @@ import ( "sort" "strings" - "github.com/jackc/pgx/v5/stdlib" mssql "github.com/microsoft/go-mssqldb" ) @@ -315,13 +314,7 @@ func (doc *TSqlDocument) ParseBatch(s *Scanner, isFirst bool) (hasMore bool) { case "create": // should be start of create procedure or create function... c := doc.parseCreate(s, createCountInBatch) - - if strings.HasSuffix(string(s.file), ".sql") { - c.Driver = &mssql.Driver{} - } - if strings.HasSuffix(string(s.file), ".pgsql") { - c.Driver = &stdlib.Driver{} - } + c.Driver = &mssql.Driver{} // *prepend* what we saw before getting to the 'create' createCountInBatch++ diff --git a/sqlparser/tsql_document_test.go b/sqlparser/tsql_document_test.go index 0d5256c..b44d575 100644 --- a/sqlparser/tsql_document_test.go +++ b/sqlparser/tsql_document_test.go @@ -522,95 +522,3 @@ func TestDocument_recoverToNextStatementCopying(t *testing.T) { assert.Equal(t, "declare", s.ReservedWord()) }) } - -func TestDocument_PostgreSQL17_parseCreate(t *testing.T) { - t.Run("parses PostgreSQL function with dollar quoting", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function test_func() returns int as $$ begin return 1; end; $$ language plpgsql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "function", create.CreateType) - 
assert.Equal(t, "test_func", create.QuotedName.Value) - }) - - t.Run("parses PostgreSQL procedure", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create procedure insert_data(a integer, b integer) language sql as $$ insert into tbl values (a, b); $$") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "procedure", create.CreateType) - assert.Equal(t, "insert_data", create.QuotedName.Value) - }) - - t.Run("parses CREATE OR REPLACE", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create or replace function test_func() returns int as $$ begin return 1; end; $$ language plpgsql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "function", create.CreateType) - }) - - t.Run("parses schema-qualified name", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function public.test_func() returns int as $$ begin return 1; end; $$ language plpgsql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Contains(t, create.QuotedName.Value, "test_func") - }) - - t.Run("parses RETURNS TABLE", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function get_users() returns table(id int, name text) as $$ select id, name from users; $$ language sql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "function", create.CreateType) - }) - - t.Run("tracks dependencies with schema prefix", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function test() returns int as $$ select * from public.table1 join public.table2 on table1.id = table2.id; $$ language sql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - require.Len(t, create.DependsOn, 2) - }) - - 
t.Run("parses volatility categories", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function test_func() returns int immutable as $$ begin return 1; end; $$ language plpgsql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "function", create.CreateType) - }) - - t.Run("parses PARALLEL SAFE", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.pgsql", "create function test_func() returns int parallel safe as $$ begin return 1; end; $$ language plpgsql") - s.NextToken() - s.NextNonWhitespaceCommentToken() - - create := doc.parseCreate(s, 0) - - assert.Equal(t, "function", create.CreateType) - }) -} From 6916d342751a845b5f9a05909b4bb5dfedbb4ab8 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 22:49:56 +0100 Subject: [PATCH 24/28] Simplified Document interface. Created Pragma struct. --- sqlparser/document.go | 4 +--- sqlparser/parser.go | 10 +++------ sqlparser/pgsql_document.go | 4 ++++ sqlparser/pragma.go | 41 +++++++++++++++++++++++++++++++++++ sqlparser/tsql_document.go | 43 ++++++++++++++----------------------- 5 files changed, 65 insertions(+), 37 deletions(-) create mode 100644 sqlparser/pragma.go diff --git a/sqlparser/document.go b/sqlparser/document.go index 21839ec..94b9020 100644 --- a/sqlparser/document.go +++ b/sqlparser/document.go @@ -19,9 +19,7 @@ type Document interface { PragmaIncludeIf() []string Include(other Document) Sort() - ParsePragmas(s *Scanner) - ParseBatch(s *Scanner, isFirst bool) (hasMore bool) - + Parse(s *Scanner) error WithoutPos() Document } diff --git a/sqlparser/parser.go b/sqlparser/parser.go index 1a49a08..c15e25f 100644 --- a/sqlparser/parser.go +++ b/sqlparser/parser.go @@ -61,14 +61,10 @@ func Parse(s *Scanner, result Document) { // // `s` will typically never be positioned on whitespace except in // whitespace-preserving parsing - - filepath.Ext(s.input) - s.NextNonWhitespaceToken() - 
result.ParsePragmas(s) - hasMore := result.ParseBatch(s, true) - for hasMore { - hasMore = result.ParseBatch(s, false) + err := result.Parse(s) + if err != nil { + panic(fmt.Sprintf("failed to parse document: %s: %e", s.file, err)) } return } diff --git a/sqlparser/pgsql_document.go b/sqlparser/pgsql_document.go index 9aacf98..24ff627 100644 --- a/sqlparser/pgsql_document.go +++ b/sqlparser/pgsql_document.go @@ -10,6 +10,10 @@ func (d PGSqlDocument) HasErrors() bool { return len(d.errors) > 0 } +func (d *PGSqlDocument) Parse(s *Scanner) error { + return nil +} + func (d PGSqlDocument) Creates() []Create { return d.creates } diff --git a/sqlparser/pragma.go b/sqlparser/pragma.go new file mode 100644 index 0000000..f0ee990 --- /dev/null +++ b/sqlparser/pragma.go @@ -0,0 +1,41 @@ +package sqlparser + +import ( + "fmt" + "strings" +) + +type Pragma struct { + pragmas []string +} + +func (d Pragma) PragmaIncludeIf() []string { + return d.pragmas +} + +func (d *Pragma) parseSinglePragma(s *Scanner) error { + pragma := strings.TrimSpace(strings.TrimPrefix(s.Token(), "--sqlcode:")) + if pragma == "" { + return nil + } + parts := strings.Split(pragma, " ") + + if len(parts) != 2 || parts[0] != "include-if" { + return fmt.Errorf("Illegal pragma: %s", s.Token()) + } + + d.pragmas = append(d.pragmas, strings.Split(parts[1], ",")...) 
+ return nil +} + +func (d *Pragma) ParsePragmas(s *Scanner) error { + for s.TokenType() == PragmaToken { + err := d.parseSinglePragma(s) + if err != nil { + return err + } + s.NextNonWhitespaceToken() + } + + return nil +} diff --git a/sqlparser/tsql_document.go b/sqlparser/tsql_document.go index 2aa35a9..1038de7 100644 --- a/sqlparser/tsql_document.go +++ b/sqlparser/tsql_document.go @@ -13,6 +13,8 @@ type TSqlDocument struct { creates []Create declares []Declare errors []Error + + Pragma } func (d TSqlDocument) HasErrors() bool { @@ -30,8 +32,19 @@ func (d TSqlDocument) Declares() []Declare { func (d TSqlDocument) Errors() []Error { return d.errors } -func (d TSqlDocument) PragmaIncludeIf() []string { - return d.pragmaIncludeIf + +func (d *TSqlDocument) Parse(s *Scanner) error { + err := d.ParsePragmas(s) + if err != nil { + d.addError(s, err.Error()) + } + + hasMore := d.parseBatch(s, true) + for hasMore { + hasMore = d.parseBatch(s, false) + } + + return nil } func (d *TSqlDocument) Sort() { @@ -79,30 +92,6 @@ func (d *TSqlDocument) Include(other Document) { d.errors = append(d.errors, other.Errors()...) } -func (d *TSqlDocument) parseSinglePragma(s *Scanner) { - pragma := strings.TrimSpace(strings.TrimPrefix(s.Token(), "--sqlcode:")) - if pragma == "" { - return - } - parts := strings.Split(pragma, " ") - if len(parts) != 2 { - d.addError(s, "Illegal pragma: "+s.Token()) - return - } - if parts[0] != "include-if" { - d.addError(s, "Illegal pragma: "+s.Token()) - return - } - d.pragmaIncludeIf = append(d.pragmaIncludeIf, strings.Split(parts[1], ",")...) 
-} - -func (d *TSqlDocument) ParsePragmas(s *Scanner) { - for s.TokenType() == PragmaToken { - d.parseSinglePragma(s) - s.NextNonWhitespaceToken() - } -} - func (d TSqlDocument) Empty() bool { return len(d.creates) == 0 || len(d.declares) == 0 } @@ -272,7 +261,7 @@ func (doc *TSqlDocument) parseDeclareBatch(s *Scanner) (hasMore bool) { } } -func (doc *TSqlDocument) ParseBatch(s *Scanner, isFirst bool) (hasMore bool) { +func (doc *TSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { var nodes []Unparsed var docstring []PosString newLineEncounteredInDocstring := false From 1a915567e2219a51e6b67b86e48f128cc92e544b Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 11 Dec 2025 23:21:53 +0100 Subject: [PATCH 25/28] [wip] pgsql document parsing --- sqlparser/pgsql_document.go | 533 +++++++++++++++++++++++++++++++++++- sqlparser/scanner.go | 2 + 2 files changed, 525 insertions(+), 10 deletions(-) diff --git a/sqlparser/pgsql_document.go b/sqlparser/pgsql_document.go index 24ff627..88ae487 100644 --- a/sqlparser/pgsql_document.go +++ b/sqlparser/pgsql_document.go @@ -1,9 +1,16 @@ package sqlparser +import ( + "fmt" + + "github.com/jackc/pgx/v5/stdlib" +) + type PGSqlDocument struct { - pragmaIncludeIf []string - creates []Create - errors []Error + creates []Create + errors []Error + + Pragma } func (d PGSqlDocument) HasErrors() bool { @@ -11,6 +18,11 @@ func (d PGSqlDocument) HasErrors() bool { } func (d *PGSqlDocument) Parse(s *Scanner) error { + err := d.ParsePragmas(s) + if err != nil { + d.errors = append(d.errors, Error{s.Start(), err.Error()}) + } + return nil } @@ -18,6 +30,7 @@ func (d PGSqlDocument) Creates() []Create { return d.creates } +// Not yet implemented func (d PGSqlDocument) Declares() []Declare { return nil } @@ -25,9 +38,6 @@ func (d PGSqlDocument) Declares() []Declare { func (d PGSqlDocument) Errors() []Error { return d.errors } -func (d PGSqlDocument) PragmaIncludeIf() []string { - return d.pragmaIncludeIf -} func (d PGSqlDocument) 
Empty() bool { return len(d.creates) == 0 @@ -41,14 +51,517 @@ func (d PGSqlDocument) Include(other Document) { } -func (d PGSqlDocument) ParsePragmas(s *Scanner) { +func (d PGSqlDocument) WithoutPos() Document { + return &PGSqlDocument{} +} + +// No GO batch separator: +// +// PostgreSQL uses semicolons (;) to separate statements, not GO. +// Multiple CREATE statements can exist in the same file. +// +// No top-level DECLARE: +// +// In PostgreSQL, DECLARE is only used inside function/procedure bodies within BEGIN...END blocks, not as top-level batch statements. +// +// Multiple CREATEs per batch: +// +// Unlike T-SQL which requires procedures/functions to be alone in a batch, PostgreSQL allows multiple CREATE statements separated by semicolons. +// +// Semicolon handling: +// +// The semicolon is a statement terminator, not a batch separator, so parsing continues after encountering one. +// +// Dollar quoting: +// +// PostgreSQL uses $$ or $tag$ for quoting function bodies instead of BEGIN...END (this would be handled in parseCreate). +// +// CREATE OR REPLACE: +// +// PostgreSQL commonly uses CREATE OR REPLACE which would need special handling in parseCreate. +// +// Schema qualification: +// +// PostgreSQL uses schema.object notation rather than [schema].[object]. 
+func (doc *PGSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { + var nodes []Unparsed + var docstring []PosString + newLineEncounteredInDocstring := false + for { + tt := s.TokenType() + switch tt { + case EOFToken: + return false + case WhitespaceToken, MultilineCommentToken: + nodes = append(nodes, CreateUnparsed(s)) + // do not reset docstring for a single trailing newline + t := s.Token() + if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { + newLineEncounteredInDocstring = true + } else { + docstring = nil + } + s.NextToken() + case SinglelineCommentToken: + // Build up a list of single line comments for the "docstring"; + // it is reset whenever we encounter something else + docstring = append(docstring, PosString{s.Start(), s.Token()}) + nodes = append(nodes, CreateUnparsed(s)) + newLineEncounteredInDocstring = false + s.NextToken() + case ReservedWordToken: + switch s.ReservedWord() { + case "declare": + // PostgreSQL doesn't have top-level DECLARE batches like T-SQL + // DECLARE is only used inside function/procedure bodies + if isFirst { + doc.addError(s, "PostgreSQL 'declare' is used inside function bodies, not as top-level batch statements") + } + nodes = append(nodes, CreateUnparsed(s)) + s.NextToken() + docstring = nil + case "create": + // Parse CREATE FUNCTION, CREATE PROCEDURE, CREATE TYPE, etc. + createStart := len(doc.creates) + c := doc.parseCreate(s, createStart) + c.Driver = &stdlib.Driver{} + + // Prepend any leading comments/whitespace + c.Body = append(nodes, c.Body...) 
+ c.Docstring = docstring + doc.creates = append(doc.creates, c) + + // Reset for next statement + nodes = nil + docstring = nil + newLineEncounteredInDocstring = false + default: + doc.addError(s, "Expected 'create', got: "+s.ReservedWord()) + s.NextToken() + docstring = nil + } + case SemicolonToken: + // PostgreSQL uses semicolons as statement terminators + // Multiple CREATE statements can exist in same file + nodes = append(nodes, CreateUnparsed(s)) + s.NextToken() + // Continue parsing - don't return like T-SQL does with GO + case BatchSeparatorToken: + // PostgreSQL doesn't use GO batch separators + // Q: Do we want to use GO batch separators as a feature of sqlcode? + doc.addError(s, "PostgreSQL does not use 'GO' batch separators; use semicolons instead") + s.NextToken() + docstring = nil + default: + doc.addError(s, fmt.Sprintf("Unexpected token in PostgreSQL document: %s", s.Token())) + s.NextToken() + docstring = nil + } + } } -func (d PGSqlDocument) WithoutPos() Document { - return &PGSqlDocument{} +// parseCreate parses PostgreSQL CREATE statements (FUNCTION, PROCEDURE, TYPE, etc.) +// Position is *on* the CREATE token. +// +// PostgreSQL CREATE syntax differences from T-SQL: +// - Supports CREATE OR REPLACE for functions/procedures +// - Uses dollar quoting ($$...$$) or $tag$...$tag$ for function bodies +// - Schema qualification uses dot notation: schema.function_name +// - Double-quoted identifiers preserve case: "MyFunction" +// - Function parameters use different syntax: func(param1 type1, param2 type2) +// - RETURNS clause specifies return type +// - LANGUAGE clause (plpgsql, sql, etc.) is required +// - Function characteristics: IMMUTABLE, STABLE, VOLATILE, PARALLEL SAFE, etc. +// +// We parse until we hit a semicolon or EOF, tracking dependencies on other objects. 
+func (doc *PGSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (result Create) { + var body []Unparsed + pos := s.Start() + + // Copy the CREATE token + CopyToken(s, &body) + s.NextNonWhitespaceCommentToken() + + // Check for OR REPLACE + if s.TokenType() == ReservedWordToken && s.ReservedWord() == "or" { + CopyToken(s, &body) + s.NextNonWhitespaceCommentToken() + + if s.TokenType() == ReservedWordToken && s.ReservedWord() == "replace" { + CopyToken(s, &body) + s.NextNonWhitespaceCommentToken() + } else { + doc.addError(s, "Expected 'REPLACE' after 'OR'") + doc.recoverToNextStatementCopying(s, &body) + result.Body = body + return + } + } + + // Parse the object type (FUNCTION, PROCEDURE, TYPE, etc.) + if s.TokenType() != ReservedWordToken { + doc.addError(s, "Expected object type after CREATE (e.g., FUNCTION, PROCEDURE, TYPE)") + doc.recoverToNextStatementCopying(s, &body) + result.Body = body + return + } + + createType := s.ReservedWord() + result.CreateType = createType + CopyToken(s, &body) + s.NextNonWhitespaceCommentToken() + + // Validate supported CREATE types + switch createType { + case "function", "procedure", "type": + // Supported types + default: + doc.addError(s, fmt.Sprintf("Unsupported CREATE type for PostgreSQL: %s", createType)) + doc.recoverToNextStatementCopying(s, &body) + result.Body = body + return + } + + // Parse the object name (with optional schema qualification) + // objectName := doc.parseQualifiedName(s, &body) + // if objectName == "" { + // doc.addError(s, "Expected object name after CREATE "+createType) + // doc.recoverToNextStatementCopying(s, &body) + // result.Body = body + // return + // } + + // result.QuotedName = PosString{pos, objectName} + + // Insist on [code]. 
+ if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" { + doc.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType)) + doc.recoverToNextStatementCopying(s, &result.Body) + return + } + result.QuotedName = doc.parseCodeschemaName(s, &result.Body) + if result.QuotedName.String() == "" { + return + } + + // Parse function/procedure signature or type definition + switch createType { + case "function", "procedure": + doc.parseFunctionSignature(s, &body, &result) + case "type": + doc.parseTypeDefinition(s, &body, &result) + } + + // Parse the rest of the CREATE statement body until semicolon or EOF + doc.parseCreateBody(s, &body, &result) + + result.Body = body + return +} + +// parseQualifiedName parses schema-qualified or simple object names +// Supports: simple_name, schema.name, "Quoted Name", schema."Quoted Name" +func (doc *PGSqlDocument) parseQualifiedName(s *Scanner, body *[]Unparsed) string { + var nameParts []string + + for { + switch s.TokenType() { + case UnquotedIdentifierToken: + nameParts = append(nameParts, s.Token()) + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + case QuotedIdentifierToken: + // PostgreSQL uses double quotes for case-sensitive identifiers + nameParts = append(nameParts, s.Token()) + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + default: + if len(nameParts) == 0 { + return "" + } + // Return the last part as the object name (without schema) + return nameParts[len(nameParts)-1] + } + + // Check for dot separator (schema.object) + if s.TokenType() == DotToken { + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + continue + } + + break + } + + if len(nameParts) == 0 { + return "" + } + return nameParts[len(nameParts)-1] +} + +// parseFunctionSignature parses function/procedure parameters and RETURNS clause +func (doc *PGSqlDocument) parseFunctionSignature(s *Scanner, body *[]Unparsed, result *Create) { + // Expect opening parenthesis for parameters + if s.TokenType() 
!= LeftParenToken { + doc.addError(s, "Expected '(' for function parameters") + return + } + + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + + // Parse parameters until closing parenthesis + parenDepth := 1 + for parenDepth > 0 { + switch s.TokenType() { + case EOFToken: + doc.addError(s, "Unexpected EOF in function parameters") + return + case LeftParenToken: + parenDepth++ + CopyToken(s, body) + s.NextToken() + case RightParenToken: + parenDepth-- + CopyToken(s, body) + s.NextToken() + case SemicolonToken: + doc.addError(s, "Unexpected semicolon in function parameters") + return + default: + CopyToken(s, body) + s.NextToken() + } + } + + s.SkipWhitespaceComments() + + // Parse RETURNS clause (for functions, not procedures) + if result.CreateType == "function" { + if s.TokenType() == ReservedWordToken && s.ReservedWord() == "returns" { + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + + // Handle RETURNS TABLE(...) + if s.TokenType() == ReservedWordToken && s.ReservedWord() == "table" { + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + + if s.TokenType() == LeftParenToken { + doc.parseReturnTable(s, body) + } + } else { + // Parse simple return type + doc.parseTypeExpression(s, body) + } + } + } +} + +// parseReturnTable parses RETURNS TABLE(...) syntax +func (doc *PGSqlDocument) parseReturnTable(s *Scanner, body *[]Unparsed) { + parenDepth := 0 + for { + switch s.TokenType() { + case EOFToken, SemicolonToken: + return + case LeftParenToken: + parenDepth++ + case RightParenToken: + parenDepth-- + CopyToken(s, body) + s.NextToken() + if parenDepth == 0 { + return + } + continue + } + CopyToken(s, body) + s.NextToken() + } +} + +// parseTypeExpression parses PostgreSQL type expressions +// Supports: int, integer, text, varchar(n), numeric(p,s), arrays (int[]), etc. 
+func (doc *PGSqlDocument) parseTypeExpression(s *Scanner, body *[]Unparsed) { + // Parse base type + if s.TokenType() != UnquotedIdentifierToken && s.TokenType() != ReservedWordToken { + return + } + + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + + // Handle array notation: type[] + // if s.TokenType() == LeftBracketToken { + // CopyToken(s, body) + // s.NextNonWhitespaceCommentToken() + + // if s.TokenType() == RightBracketToken { + // CopyToken(s, body) + // s.NextNonWhitespaceCommentToken() + // } + // } + + // Handle type parameters: varchar(100), numeric(10,2) + if s.TokenType() == LeftParenToken { + parenDepth := 1 + CopyToken(s, body) + s.NextToken() + + for parenDepth > 0 { + switch s.TokenType() { + case EOFToken, SemicolonToken: + return + case LeftParenToken: + parenDepth++ + case RightParenToken: + parenDepth-- + } + CopyToken(s, body) + s.NextToken() + } + } +} + +// parseTypeDefinition parses CREATE TYPE syntax +// Supports: ENUM, composite types, range types +func (doc *PGSqlDocument) parseTypeDefinition(s *Scanner, body *[]Unparsed, result *Create) { + // TYPE definitions use AS keyword + if s.TokenType() == ReservedWordToken && s.ReservedWord() == "as" { + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + + // Check for ENUM, RANGE, or composite type + if s.TokenType() == ReservedWordToken { + typeKind := s.ReservedWord() + switch typeKind { + case "enum", "range": + CopyToken(s, body) + s.NextNonWhitespaceCommentToken() + } + } + } } -func (d PGSqlDocument) ParseBatch(s *Scanner, isFirst bool) bool { +// parseCreateBody parses the body of a CREATE statement +// Handles dollar-quoted strings, tracks dependencies, continues until semicolon/EOF +func (doc *PGSqlDocument) parseCreateBody(s *Scanner, body *[]Unparsed, result *Create) { + dollarQuoteDepth := 0 + var currentDollarTag string + + for { + switch s.TokenType() { + case EOFToken: + return + case SemicolonToken: + // Statement terminator - we're done + CopyToken(s, body) + 
s.NextToken() + return + case DollarQuotedStringStartToken: + // PostgreSQL dollar quoting: $$...$$ or $tag$...$tag$ + currentDollarTag = s.Token() + dollarQuoteDepth++ + CopyToken(s, body) + s.NextToken() + case DollarQuotedStringEndToken: + if s.Token() == currentDollarTag { + dollarQuoteDepth-- + } + CopyToken(s, body) + s.NextToken() + if dollarQuoteDepth == 0 { + currentDollarTag = "" + } + case UnquotedIdentifierToken, QuotedIdentifierToken: + // Track dependencies on tables/views/functions + // In PostgreSQL, identifiers can be qualified: schema.object + identifier := s.Token() + + // Check if this might be a dependency (after FROM, JOIN, etc.) + if doc.mightBeDependency(s) { + // Extract just the object name (without schema prefix) + objectName := doc.extractObjectName(identifier) + result.DependsOn = append(result.DependsOn, PosString{s.Start(), objectName}) + } + + CopyToken(s, body) + s.NextToken() + default: + CopyToken(s, body) + s.NextToken() + } + } +} + +// mightBeDependency checks if current context suggests a table/view/function reference +func (doc *PGSqlDocument) mightBeDependency(s *Scanner) bool { + // Simple heuristic: look back for FROM, JOIN, INTO, etc. 
+ // This would need to track parse context for accurate dependency detection + return false // Placeholder - implement context-aware dependency tracking +} + +// extractObjectName extracts object name from schema-qualified identifier +func (doc *PGSqlDocument) extractObjectName(identifier string) string { + // Handle schema.object notation + // For now, return as-is; proper implementation would split on dot + return identifier +} + +// recoverToNextStatementCopying recovers from parse errors by skipping to next statement +func (doc *PGSqlDocument) recoverToNextStatementCopying(s *Scanner, target *[]Unparsed) { + for { + NextTokenCopyingWhitespace(s, target) + switch s.TokenType() { + case ReservedWordToken: + switch s.ReservedWord() { + case "create", "drop", "alter": + return + } + case EOFToken, SemicolonToken: + return + default: + CopyToken(s, target) + } + } +} + +func (doc *PGSqlDocument) addError(s *Scanner, err string) { + doc.errors = append(doc.errors, Error{ + s.Start(), err, + }) +} + +func (doc *PGSqlDocument) parseDeclareBatch(s *Scanner) (hasMore bool) { + // PostgreSQL doesn't have top-level DECLARE batches like T-SQL + // DECLARE is only used inside function/procedure bodies (in BEGIN...END blocks) + doc.addError(s, "PostgreSQL does not support top-level DECLARE statements outside of function bodies") + doc.recoverToNextStatement(s) return false } + +func (doc *PGSqlDocument) parseBatchSeparator(s *Scanner) { + // PostgreSQL doesn't use GO batch separators + doc.addError(s, "PostgreSQL does not use 'GO' batch separators; use semicolons") + s.NextToken() +} + +func (doc *PGSqlDocument) recoverToNextStatement(s *Scanner) { + // We hit an unexpected token ... 
as an heuristic for continuing parsing, + // skip parsing until we hit a reserved word that starts a statement + // we recognize + for { + s.NextNonWhitespaceCommentToken() + switch s.TokenType() { + case ReservedWordToken: + switch s.ReservedWord() { + case "create", "drop", "alter": + return + } + case EOFToken, SemicolonToken: + return + } + } +} diff --git a/sqlparser/scanner.go b/sqlparser/scanner.go index a5fb75a..9e6c27f 100644 --- a/sqlparser/scanner.go +++ b/sqlparser/scanner.go @@ -565,6 +565,8 @@ var reservedWords = map[string]struct{}{ "writetext": struct{}{}, "exit": struct{}{}, "proc": struct{}{}, + // pgsql + "replace": struct{}{}, } // apparently 'within group' is also reserved but dropping that.. From fb5641489745059add55328b513fc20d7844df3b Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Tue, 16 Dec 2025 19:23:28 +0100 Subject: [PATCH 26/28] Simplify the interfaces for parsing a SQL document. --- sqlparser/document.go | 92 +++++++++++++++++ sqlparser/dom.go | 24 ----- sqlparser/nodes.go | 91 +++++++++++++++++ sqlparser/pgsql_document.go | 95 ++++++++---------- sqlparser/tsql_document.go | 191 ++++++++---------------------------- sqlparser/unparsed.go | 25 +++++ 6 files changed, 287 insertions(+), 231 deletions(-) create mode 100644 sqlparser/nodes.go create mode 100644 sqlparser/unparsed.go diff --git a/sqlparser/document.go b/sqlparser/document.go index 94b9020..1721875 100644 --- a/sqlparser/document.go +++ b/sqlparser/document.go @@ -1,7 +1,9 @@ package sqlparser import ( + "fmt" "path/filepath" + "slices" "strings" ) @@ -41,3 +43,93 @@ func NewDocumentFromExtension(extension string) Document { panic("unhandled document type: " + extension) } } + +// parseCodeschemaName parses `[code] . something`, and returns `something` +// in quoted form (`[something]`). Also copy to `target`. Empty string on error. +// Note: To follow conventions, consume one extra token at the end even if we know +// it fill not be consumed by this function... 
+func ParseCodeschemaName(s *Scanner, target *[]Unparsed, statementTokens []string) (PosString, error) { + CopyToken(s, target) + NextTokenCopyingWhitespace(s, target) + if s.TokenType() != DotToken { + RecoverToNextStatementCopying(s, target, statementTokens) + return PosString{Value: ""}, fmt.Errorf("[code] must be followed by '.'") + } + + CopyToken(s, target) + + NextTokenCopyingWhitespace(s, target) + switch s.TokenType() { + case UnquotedIdentifierToken: + // To get something uniform for comparison, quote all names + CopyToken(s, target) + result := PosString{Pos: s.Start(), Value: "[" + s.Token() + "]"} + NextTokenCopyingWhitespace(s, target) + return result, nil + case QuotedIdentifierToken: + CopyToken(s, target) + result := PosString{Pos: s.Start(), Value: s.Token()} + NextTokenCopyingWhitespace(s, target) + return result, nil + default: + RecoverToNextStatementCopying(s, target, statementTokens) + return PosString{Value: ""}, fmt.Errorf("[code]. must be followed an identifier") + } +} + +// NextTokenCopyingWhitespace is like s.NextToken(), but if whitespace is encountered +// it is simply copied into `target`. Upon return, the scanner is located at a non-whitespace +// token, and target is either unmodified or filled with some whitespace nodes. +func NextTokenCopyingWhitespace(s *Scanner, target *[]Unparsed) { + for { + tt := s.NextToken() + switch tt { + case EOFToken, BatchSeparatorToken: + // do not copy + return + case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: + // copy, and loop around + CopyToken(s, target) + continue + default: + return + } + } + +} + +func RecoverToNextStatementCopying(s *Scanner, target *[]Unparsed, StatementTokens []string) { + // We hit an unexpected token ... 
as an heuristic for continuing parsing, + // skip parsing until we hit a reserved word that starts a statement + // we recognize + for { + NextTokenCopyingWhitespace(s, target) + switch s.TokenType() { + case ReservedWordToken: + if slices.Contains(StatementTokens, s.ReservedWord()) { + return + } + case EOFToken: + return + default: + CopyToken(s, target) + } + } +} + +func RecoverToNextStatement(s *Scanner, StatementTokens []string) { + // We hit an unexpected token ... as an heuristic for continuing parsing, + // skip parsing until we hit a reserved word that starts a statement + // we recognize + for { + s.NextNonWhitespaceCommentToken() + switch s.TokenType() { + case ReservedWordToken: + if slices.Contains(StatementTokens, s.ReservedWord()) { + return + } + case EOFToken: + return + } + } +} diff --git a/sqlparser/dom.go b/sqlparser/dom.go index 0c72587..a75db38 100644 --- a/sqlparser/dom.go +++ b/sqlparser/dom.go @@ -5,21 +5,6 @@ import ( "strings" ) -type Unparsed struct { - Type TokenType - Start, Stop Pos - RawValue string -} - -func (u Unparsed) WithoutPos() Unparsed { - return Unparsed{ - Type: u.Type, - Start: Pos{}, - Stop: Pos{}, - RawValue: u.RawValue, - } -} - type Declare struct { Start Pos Stop Pos @@ -82,12 +67,3 @@ func (e Error) Error() string { func (e Error) WithoutPos() Error { return Error{Message: e.Message} } - -func CreateUnparsed(s *Scanner) Unparsed { - return Unparsed{ - Type: s.TokenType(), - Start: s.Start(), - Stop: s.Stop(), - RawValue: s.Token(), - } -} diff --git a/sqlparser/nodes.go b/sqlparser/nodes.go new file mode 100644 index 0000000..7b2033e --- /dev/null +++ b/sqlparser/nodes.go @@ -0,0 +1,91 @@ +package sqlparser + +import ( + "fmt" +) + +type Nodes struct { + Nodes []Unparsed + DocString []PosString + CreateStatements int + TokenHandlers map[string]func(*Scanner, *Nodes) bool + Errors []Error + BatchSeparatorToken TokenType +} + +func (n *Nodes) Create(s *Scanner) { + n.Nodes = append(n.Nodes, CreateUnparsed(s)) +} + 
+func (n *Nodes) HasErrors() bool { + return len(n.Errors) > 0 +} + +// Agnostic parser that handles comments, whitespace, and reserved words +func (n *Nodes) Parse(s *Scanner) bool { + newLineEncounteredInDocstring := false + + for { + tt := s.TokenType() + switch tt { + case EOFToken: + return false + case WhitespaceToken, MultilineCommentToken: + n.Create(s) + // do not reset token for a single trailing newline + t := s.Token() + if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { + newLineEncounteredInDocstring = true + } else { + n.DocString = nil + } + s.NextToken() + case SinglelineCommentToken: + // We build up a list of single line comments for the "docstring"; + // it is reset whenever we encounter something else + n.DocString = append(n.DocString, PosString{s.Start(), s.Token()}) + n.Create(s) + newLineEncounteredInDocstring = false + s.NextToken() + case ReservedWordToken: + token := s.ReservedWord() + handler, exists := n.TokenHandlers[token] + if !exists { + n.Errors = append(n.Errors, Error{ + s.Start(), fmt.Sprintf("Expected , got: %s", token), + }) + s.NextToken() + } else { + if handler(s, n) { + // regardless of errors, go on and parse as far as we get... 
+ return true + } + } + case BatchSeparatorToken: + // TODO + errorEmitted := false + for { + switch s.NextToken() { + case WhitespaceToken: + continue + case MalformedBatchSeparatorToken: + if !errorEmitted { + n.Errors = append(n.Errors, Error{ + s.Start(), "`go` should be alone on a line without any comments", + }) + errorEmitted = true + } + continue + default: + return true + } + } + default: + n.Errors = append(n.Errors, Error{ + s.Start(), fmt.Sprintf("Unexpected token: %s", s.Token()), + }) + s.NextToken() + n.DocString = nil + } + } +} diff --git a/sqlparser/pgsql_document.go b/sqlparser/pgsql_document.go index 88ae487..181e4a0 100644 --- a/sqlparser/pgsql_document.go +++ b/sqlparser/pgsql_document.go @@ -6,6 +6,8 @@ import ( "github.com/jackc/pgx/v5/stdlib" ) +var PGSQLStatementTokens = []string{"create"} + type PGSqlDocument struct { creates []Create errors []Error @@ -84,6 +86,30 @@ func (d PGSqlDocument) WithoutPos() Document { // // PostgreSQL uses schema.object notation rather than [schema].[object]. func (doc *PGSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { + nodes := &Nodes{ + TokenHandlers: map[string]func(*Scanner, *Nodes) bool{ + "create": func(s *Scanner, n *Nodes) bool { + // Parse CREATE FUNCTION, CREATE PROCEDURE, CREATE TYPE, etc. + c := doc.parseCreate(s, n.CreateStatements) + c.Driver = &stdlib.Driver{} + + // Prepend any leading comments/whitespace + c.Body = append(n.Nodes, c.Body...) + c.Docstring = n.DocString + doc.creates = append(doc.creates, c) + + return false + }, + }, + } + + hasMore = nodes.Parse(s) + if nodes.HasErrors() { + doc.errors = append(doc.errors, nodes.Errors...) + } + + return hasMore + var nodes []Unparsed var docstring []PosString newLineEncounteredInDocstring := false @@ -177,13 +203,14 @@ func (doc *PGSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { // We parse until we hit a semicolon or EOF, tracking dependencies on other objects. 
func (doc *PGSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (result Create) { var body []Unparsed - pos := s.Start() // Copy the CREATE token CopyToken(s, &body) s.NextNonWhitespaceCommentToken() // Check for OR REPLACE + // NOTE: "or replace" doesn't make sense within sqlcode as this will be created within a new + // schema. if s.TokenType() == ReservedWordToken && s.ReservedWord() == "or" { CopyToken(s, &body) s.NextNonWhitespaceCommentToken() @@ -193,7 +220,7 @@ func (doc *PGSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (resul s.NextNonWhitespaceCommentToken() } else { doc.addError(s, "Expected 'REPLACE' after 'OR'") - doc.recoverToNextStatementCopying(s, &body) + RecoverToNextStatementCopying(s, &body, PGSQLStatementTokens) result.Body = body return } @@ -202,7 +229,7 @@ func (doc *PGSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (resul // Parse the object type (FUNCTION, PROCEDURE, TYPE, etc.) if s.TokenType() != ReservedWordToken { doc.addError(s, "Expected object type after CREATE (e.g., FUNCTION, PROCEDURE, TYPE)") - doc.recoverToNextStatementCopying(s, &body) + RecoverToNextStatementCopying(s, &body, PGSQLStatementTokens) result.Body = body return } @@ -218,29 +245,23 @@ func (doc *PGSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (resul // Supported types default: doc.addError(s, fmt.Sprintf("Unsupported CREATE type for PostgreSQL: %s", createType)) - doc.recoverToNextStatementCopying(s, &body) + RecoverToNextStatementCopying(s, &body, PGSQLStatementTokens) result.Body = body return } - // Parse the object name (with optional schema qualification) - // objectName := doc.parseQualifiedName(s, &body) - // if objectName == "" { - // doc.addError(s, "Expected object name after CREATE "+createType) - // doc.recoverToNextStatementCopying(s, &body) - // result.Body = body - // return - // } - - // result.QuotedName = PosString{pos, objectName} - - // Insist on [code]. 
+ // Insist on [code] to provide the ability for sqlcode to patch function bodies + // with references to other sqlcode objects. if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" { doc.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType)) - doc.recoverToNextStatementCopying(s, &result.Body) + RecoverToNextStatementCopying(s, &result.Body, PGSQLStatementTokens) return } - result.QuotedName = doc.parseCodeschemaName(s, &result.Body) + var err error + result.QuotedName, err = ParseCodeschemaName(s, &result.Body, PGSQLStatementTokens) + if err != nil { + doc.addError(s, err.Error()) + } if result.QuotedName.String() == "" { return } @@ -510,24 +531,6 @@ func (doc *PGSqlDocument) extractObjectName(identifier string) string { return identifier } -// recoverToNextStatementCopying recovers from parse errors by skipping to next statement -func (doc *PGSqlDocument) recoverToNextStatementCopying(s *Scanner, target *[]Unparsed) { - for { - NextTokenCopyingWhitespace(s, target) - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "create", "drop", "alter": - return - } - case EOFToken, SemicolonToken: - return - default: - CopyToken(s, target) - } - } -} - func (doc *PGSqlDocument) addError(s *Scanner, err string) { doc.errors = append(doc.errors, Error{ s.Start(), err, @@ -538,7 +541,7 @@ func (doc *PGSqlDocument) parseDeclareBatch(s *Scanner) (hasMore bool) { // PostgreSQL doesn't have top-level DECLARE batches like T-SQL // DECLARE is only used inside function/procedure bodies (in BEGIN...END blocks) doc.addError(s, "PostgreSQL does not support top-level DECLARE statements outside of function bodies") - doc.recoverToNextStatement(s) + RecoverToNextStatement(s, PGSQLStatementTokens) return false } @@ -547,21 +550,3 @@ func (doc *PGSqlDocument) parseBatchSeparator(s *Scanner) { doc.addError(s, "PostgreSQL does not use 'GO' batch separators; use semicolons") s.NextToken() } - -func (doc 
*PGSqlDocument) recoverToNextStatement(s *Scanner) { - // We hit an unexpected token ... as an heuristic for continuing parsing, - // skip parsing until we hit a reserved word that starts a statement - // we recognize - for { - s.NextNonWhitespaceCommentToken() - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "create", "drop", "alter": - return - } - case EOFToken, SemicolonToken: - return - } - } -} diff --git a/sqlparser/tsql_document.go b/sqlparser/tsql_document.go index 1038de7..8c17499 100644 --- a/sqlparser/tsql_document.go +++ b/sqlparser/tsql_document.go @@ -8,6 +8,8 @@ import ( mssql "github.com/microsoft/go-mssqldb" ) +var TSQLStatementTokens = []string{"create", "declare", "go"} + type TSqlDocument struct { pragmaIncludeIf []string creates []Create @@ -118,7 +120,7 @@ func (doc *TSqlDocument) parseTypeExpression(s *Scanner) (t Type) { t.Args = append(t.Args, "max") default: doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) + RecoverToNextStatement(s, TSQLStatementTokens) return } s.NextNonWhitespaceCommentToken() @@ -131,7 +133,7 @@ func (doc *TSqlDocument) parseTypeExpression(s *Scanner) (t Type) { return default: doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) + RecoverToNextStatement(s, TSQLStatementTokens) return } } @@ -156,7 +158,7 @@ loop: for { if s.TokenType() != VariableIdentifierToken { doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) + RecoverToNextStatement(s, TSQLStatementTokens) return } @@ -179,7 +181,7 @@ loop: if s.TokenType() != EqualToken { doc.addError(s, "sqlcode constants needs to be assigned at once using =") - doc.recoverToNextStatement(s) + RecoverToNextStatement(s, TSQLStatementTokens) } switch s.NextNonWhitespaceCommentToken() { @@ -194,7 +196,7 @@ loop: result = append(result, declare) default: doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) + RecoverToNextStatement(s, TSQLStatementTokens) return } @@ -250,151 +252,50 @@ func (doc *TSqlDocument) 
parseDeclareBatch(s *Scanner) (hasMore bool) { doc.declares = append(doc.declares, d...) case tt == ReservedWordToken && s.ReservedWord() != "declare": doc.addError(s, "Only 'declare' allowed in this batch") - doc.recoverToNextStatement(s) + RecoverToNextStatement(s, TSQLStatementTokens) case tt == BatchSeparatorToken: doc.parseBatchSeparator(s) return true default: doc.unexpectedTokenError(s) - doc.recoverToNextStatement(s) + RecoverToNextStatement(s, TSQLStatementTokens) } } } func (doc *TSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { - var nodes []Unparsed - var docstring []PosString - newLineEncounteredInDocstring := false - - var createCountInBatch int - - for { - tt := s.TokenType() - switch tt { - case EOFToken: - return false - case WhitespaceToken, MultilineCommentToken: - nodes = append(nodes, CreateUnparsed(s)) - // do not reset token for a single trailing newline - t := s.Token() - if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { - newLineEncounteredInDocstring = true - } else { - docstring = nil - } - s.NextToken() - case SinglelineCommentToken: - // We build up a list of single line comments for the "docstring"; - // it is reset whenever we encounter something else - docstring = append(docstring, PosString{s.Start(), s.Token()}) - nodes = append(nodes, CreateUnparsed(s)) - newLineEncounteredInDocstring = false - s.NextToken() - case ReservedWordToken: - switch s.ReservedWord() { - case "declare": + nodes := &Nodes{ + TokenHandlers: map[string]func(*Scanner, *Nodes) bool{ + "declare": func(s *Scanner, n *Nodes) bool { // First declare-statement; enter a mode where we assume all contents // of batch are declare statements if !isFirst { doc.addError(s, "'declare' statement only allowed in first batch") } + // regardless of errors, go on and parse as far as we get... 
return doc.parseDeclareBatch(s) - case "create": + }, + "create": func(s *Scanner, n *Nodes) bool { // should be start of create procedure or create function... - c := doc.parseCreate(s, createCountInBatch) + c := doc.parseCreate(s, n.CreateStatements) c.Driver = &mssql.Driver{} // *prepend* what we saw before getting to the 'create' - createCountInBatch++ - c.Body = append(nodes, c.Body...) - c.Docstring = docstring + n.CreateStatements++ + c.Body = append(n.Nodes, c.Body...) + c.Docstring = n.DocString doc.creates = append(doc.creates, c) - default: - doc.addError(s, "Expected 'declare' or 'create', got: "+s.ReservedWord()) - s.NextToken() - } - case BatchSeparatorToken: - doc.parseBatchSeparator(s) - return true - default: - doc.unexpectedTokenError(s) - s.NextToken() - docstring = nil - } + return false + }, + }, } -} - -func (d *TSqlDocument) recoverToNextStatementCopying(s *Scanner, target *[]Unparsed) { - // We hit an unexpected token ... as an heuristic for continuing parsing, - // skip parsing until we hit a reserved word that starts a statement - // we recognize - for { - NextTokenCopyingWhitespace(s, target) - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "declare", "create", "go": - return - } - case EOFToken: - return - default: - CopyToken(s, target) - } + hasMore = nodes.Parse(s) + if nodes.HasErrors() { + doc.errors = append(doc.errors, nodes.Errors...) } -} -func (d *TSqlDocument) recoverToNextStatement(s *Scanner) { - // We hit an unexpected token ... as an heuristic for continuing parsing, - // skip parsing until we hit a reserved word that starts a statement - // we recognize - for { - s.NextNonWhitespaceCommentToken() - switch s.TokenType() { - case ReservedWordToken: - switch s.ReservedWord() { - case "declare", "create", "go": - return - } - case EOFToken: - return - } - } -} - -// parseCodeschemaName parses `[code] . something`, and returns `something` -// in quoted form (`[something]`). 
Also copy to `target`. Empty string on error. -// Note: To follow conventions, consume one extra token at the end even if we know -// it fill not be consumed by this function... -func (d *TSqlDocument) parseCodeschemaName(s *Scanner, target *[]Unparsed) PosString { - CopyToken(s, target) - NextTokenCopyingWhitespace(s, target) - if s.TokenType() != DotToken { - d.addError(s, fmt.Sprintf("[code] must be followed by '.'")) - d.recoverToNextStatementCopying(s, target) - return PosString{Value: ""} - } - CopyToken(s, target) - - NextTokenCopyingWhitespace(s, target) - switch s.TokenType() { - case UnquotedIdentifierToken: - // To get something uniform for comparison, quote all names - CopyToken(s, target) - result := PosString{Pos: s.Start(), Value: "[" + s.Token() + "]"} - NextTokenCopyingWhitespace(s, target) - return result - case QuotedIdentifierToken: - CopyToken(s, target) - result := PosString{Pos: s.Start(), Value: s.Token()} - NextTokenCopyingWhitespace(s, target) - return result - default: - d.addError(s, fmt.Sprintf("[code]. must be followed an identifier")) - d.recoverToNextStatementCopying(s, target) - return PosString{Value: ""} - } + return hasMore } // parseCreate parses anything that starts with "create". 
Position is @@ -417,12 +318,12 @@ func (d *TSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (result C createType := strings.ToLower(s.Token()) if !(createType == "procedure" || createType == "function" || createType == "type") { d.addError(s, fmt.Sprintf("sqlcode only supports creating procedures, functions or types; not `%s`", createType)) - d.recoverToNextStatementCopying(s, &result.Body) + RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) return } if (createType == "procedure" || createType == "function") && createCountInBatch > 0 { d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") - d.recoverToNextStatementCopying(s, &result.Body) + RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) return } @@ -434,10 +335,14 @@ func (d *TSqlDocument) parseCreate(s *Scanner, createCountInBatch int) (result C // Insist on [code]. if s.TokenType() != QuotedIdentifierToken || s.Token() != "[code]" { d.addError(s, fmt.Sprintf("create %s must be followed by [code].", result.CreateType)) - d.recoverToNextStatementCopying(s, &result.Body) + RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) return } - result.QuotedName = d.parseCodeschemaName(s, &result.Body) + var err error + result.QuotedName, err = ParseCodeschemaName(s, &result.Body, TSQLStatementTokens) + if err != nil { + d.addError(s, err.Error()) + } if result.QuotedName.String() == "" { return } @@ -471,7 +376,7 @@ tailloop: if (tt2 == ReservedWordToken && (s.ReservedWord() == "function" || s.ReservedWord() == "procedure")) || (tt2 == UnquotedIdentifierToken && s.TokenLower() == "type") { - d.recoverToNextStatementCopying(s, &result.Body) + RecoverToNextStatementCopying(s, &result.Body, TSQLStatementTokens) d.addError(s, "a procedure/function must be alone in a batch; use 'go' to split batches") return } @@ -488,7 +393,10 @@ tailloop: break tailloop case tt == QuotedIdentifierToken && s.Token() == "[code]": // Parse a 
dependency - dep := d.parseCodeschemaName(s, &result.Body) + dep, err := ParseCodeschemaName(s, &result.Body, TSQLStatementTokens) + if err != nil { + d.addError(s, err.Error()) + } found := false for _, existing := range result.DependsOn { if existing.Value == dep.Value { @@ -549,24 +457,3 @@ tailloop: }) return } - -// NextTokenCopyingWhitespace is like s.NextToken(), but if whitespace is encountered -// it is simply copied into `target`. Upon return, the scanner is located at a non-whitespace -// token, and target is either unmodified or filled with some whitespace nodes. -func NextTokenCopyingWhitespace(s *Scanner, target *[]Unparsed) { - for { - tt := s.NextToken() - switch tt { - case EOFToken, BatchSeparatorToken: - // do not copy - return - case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: - // copy, and loop around - CopyToken(s, target) - continue - default: - return - } - } - -} diff --git a/sqlparser/unparsed.go b/sqlparser/unparsed.go new file mode 100644 index 0000000..45a85bf --- /dev/null +++ b/sqlparser/unparsed.go @@ -0,0 +1,25 @@ +package sqlparser + +type Unparsed struct { + Type TokenType + Start, Stop Pos + RawValue string +} + +func CreateUnparsed(s *Scanner) Unparsed { + return Unparsed{ + Type: s.TokenType(), + Start: s.Start(), + Stop: s.Stop(), + RawValue: s.Token(), + } +} + +func (u Unparsed) WithoutPos() Unparsed { + return Unparsed{ + Type: u.Type, + Start: Pos{}, + Stop: Pos{}, + RawValue: u.RawValue, + } +} From 5e29dc931c7efc5625ec2aca646bfde796549d83 Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Tue, 16 Dec 2025 19:24:21 +0100 Subject: [PATCH 27/28] Refactored pgsql document to use node parser. 
--- sqlparser/pgsql_document.go | 150 ++++++++++++++++++------------------ 1 file changed, 75 insertions(+), 75 deletions(-) diff --git a/sqlparser/pgsql_document.go b/sqlparser/pgsql_document.go index 181e4a0..f6e2e8f 100644 --- a/sqlparser/pgsql_document.go +++ b/sqlparser/pgsql_document.go @@ -110,81 +110,81 @@ func (doc *PGSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { return hasMore - var nodes []Unparsed - var docstring []PosString - newLineEncounteredInDocstring := false - - for { - tt := s.TokenType() - switch tt { - case EOFToken: - return false - case WhitespaceToken, MultilineCommentToken: - nodes = append(nodes, CreateUnparsed(s)) - // do not reset docstring for a single trailing newline - t := s.Token() - if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { - newLineEncounteredInDocstring = true - } else { - docstring = nil - } - s.NextToken() - case SinglelineCommentToken: - // Build up a list of single line comments for the "docstring"; - // it is reset whenever we encounter something else - docstring = append(docstring, PosString{s.Start(), s.Token()}) - nodes = append(nodes, CreateUnparsed(s)) - newLineEncounteredInDocstring = false - s.NextToken() - case ReservedWordToken: - switch s.ReservedWord() { - case "declare": - // PostgreSQL doesn't have top-level DECLARE batches like T-SQL - // DECLARE is only used inside function/procedure bodies - if isFirst { - doc.addError(s, "PostgreSQL 'declare' is used inside function bodies, not as top-level batch statements") - } - nodes = append(nodes, CreateUnparsed(s)) - s.NextToken() - docstring = nil - case "create": - // Parse CREATE FUNCTION, CREATE PROCEDURE, CREATE TYPE, etc. - createStart := len(doc.creates) - c := doc.parseCreate(s, createStart) - c.Driver = &stdlib.Driver{} - - // Prepend any leading comments/whitespace - c.Body = append(nodes, c.Body...) 
- c.Docstring = docstring - doc.creates = append(doc.creates, c) - - // Reset for next statement - nodes = nil - docstring = nil - newLineEncounteredInDocstring = false - default: - doc.addError(s, "Expected 'create', got: "+s.ReservedWord()) - s.NextToken() - docstring = nil - } - case SemicolonToken: - // PostgreSQL uses semicolons as statement terminators - // Multiple CREATE statements can exist in same file - nodes = append(nodes, CreateUnparsed(s)) - s.NextToken() - // Continue parsing - don't return like T-SQL does with GO - case BatchSeparatorToken: - // PostgreSQL doesn't use GO batch separators - // Q: Do we want to use GO batch separators as a feature of sqlcode? - doc.addError(s, "PostgreSQL does not use 'GO' batch separators; use semicolons instead") - s.NextToken() - docstring = nil - default: - doc.addError(s, fmt.Sprintf("Unexpected token in PostgreSQL document: %s", s.Token())) - s.NextToken() - docstring = nil - } - } + // var nodes []Unparsed + // var docstring []PosString + // newLineEncounteredInDocstring := false + + // for { + // tt := s.TokenType() + // switch tt { + // case EOFToken: + // return false + // case WhitespaceToken, MultilineCommentToken: + // nodes = append(nodes, CreateUnparsed(s)) + // // do not reset docstring for a single trailing newline + // t := s.Token() + // if !newLineEncounteredInDocstring && (t == "\n" || t == "\r\n") { + // newLineEncounteredInDocstring = true + // } else { + // docstring = nil + // } + // s.NextToken() + // case SinglelineCommentToken: + // // Build up a list of single line comments for the "docstring"; + // // it is reset whenever we encounter something else + // docstring = append(docstring, PosString{s.Start(), s.Token()}) + // nodes = append(nodes, CreateUnparsed(s)) + // newLineEncounteredInDocstring = false + // s.NextToken() + // case ReservedWordToken: + // switch s.ReservedWord() { + // case "declare": + // // PostgreSQL doesn't have top-level DECLARE batches like T-SQL + // // DECLARE is 
only used inside function/procedure bodies + // if isFirst { + // doc.addError(s, "PostgreSQL 'declare' is used inside function bodies, not as top-level batch statements") + // } + // nodes = append(nodes, CreateUnparsed(s)) + // s.NextToken() + // docstring = nil + // case "create": + // // Parse CREATE FUNCTION, CREATE PROCEDURE, CREATE TYPE, etc. + // createStart := len(doc.creates) + // c := doc.parseCreate(s, createStart) + // c.Driver = &stdlib.Driver{} + + // // Prepend any leading comments/whitespace + // c.Body = append(nodes, c.Body...) + // c.Docstring = docstring + // doc.creates = append(doc.creates, c) + + // // Reset for next statement + // nodes = nil + // docstring = nil + // newLineEncounteredInDocstring = false + // default: + // doc.addError(s, "Expected 'create', got: "+s.ReservedWord()) + // s.NextToken() + // docstring = nil + // } + // case SemicolonToken: + // // PostgreSQL uses semicolons as statement terminators + // // Multiple CREATE statements can exist in same file + // nodes = append(nodes, CreateUnparsed(s)) + // s.NextToken() + // // Continue parsing - don't return like T-SQL does with GO + // case BatchSeparatorToken: + // // PostgreSQL doesn't use GO batch separators + // // Q: Do we want to use GO batch separators as a feature of sqlcode? + // doc.addError(s, "PostgreSQL does not use 'GO' batch separators; use semicolons instead") + // s.NextToken() + // docstring = nil + // default: + // doc.addError(s, fmt.Sprintf("Unexpected token in PostgreSQL document: %s", s.Token())) + // s.NextToken() + // docstring = nil + // } + // } } // parseCreate parses PostgreSQL CREATE statements (FUNCTION, PROCEDURE, TYPE, etc.) 
From 6e114a937c7bfd2637435868655555a7b4d7e13e Mon Sep 17 00:00:00 2001 From: Kyle Simukka Date: Thu, 18 Dec 2025 20:21:29 +0100 Subject: [PATCH 28/28] [wip] --- sqlparser/{nodes.go => batch.go} | 10 +-- sqlparser/document_test.go | 95 ++++++++++++++++++++++++++ sqlparser/node_test.go | 1 + sqlparser/pgsql_document.go | 12 ++-- sqlparser/scanner.go | 13 ++-- sqlparser/tokentype.go | 7 ++ sqlparser/tsql_document.go | 15 ++-- sqlparser/tsql_document_test.go | 114 ++----------------------------- 8 files changed, 135 insertions(+), 132 deletions(-) rename sqlparser/{nodes.go => batch.go} (91%) create mode 100644 sqlparser/node_test.go diff --git a/sqlparser/nodes.go b/sqlparser/batch.go similarity index 91% rename from sqlparser/nodes.go rename to sqlparser/batch.go index 7b2033e..25504f5 100644 --- a/sqlparser/nodes.go +++ b/sqlparser/batch.go @@ -4,25 +4,25 @@ import ( "fmt" ) -type Nodes struct { +type Batch struct { Nodes []Unparsed DocString []PosString CreateStatements int - TokenHandlers map[string]func(*Scanner, *Nodes) bool + TokenHandlers map[string]func(*Scanner, *Batch) bool Errors []Error BatchSeparatorToken TokenType } -func (n *Nodes) Create(s *Scanner) { +func (n *Batch) Create(s *Scanner) { n.Nodes = append(n.Nodes, CreateUnparsed(s)) } -func (n *Nodes) HasErrors() bool { +func (n *Batch) HasErrors() bool { return len(n.Errors) > 0 } // Agnostic parser that handles comments, whitespace, and reserved words -func (n *Nodes) Parse(s *Scanner) bool { +func (n *Batch) Parse(s *Scanner) bool { newLineEncounteredInDocstring := false for { diff --git a/sqlparser/document_test.go b/sqlparser/document_test.go index 8e497e6..ec011c5 100644 --- a/sqlparser/document_test.go +++ b/sqlparser/document_test.go @@ -1,6 +1,7 @@ package sqlparser import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -54,3 +55,97 @@ func TestNewDocumentFromExtension(t *testing.T) { require.NotEqual(t, sqlDoc, pgsqlDoc) }) } + +func TestDocument_parseCodeschemaName(t *testing.T) 
{ + t.Run("parses unquoted identifier", func(t *testing.T) { + s := NewScanner("test.sql", "[code].TestProc") + s.NextToken() + var target []Unparsed + + result, err := ParseCodeschemaName(s, &target, nil) + assert.NoError(t, err) + assert.Equal(t, "[TestProc]", result.Value) + assert.NotEmpty(t, target) + }) + + t.Run("parses quoted identifier", func(t *testing.T) { + s := NewScanner("test.sql", "[code].[Test Proc]") + s.NextToken() + var target []Unparsed + + result, err := ParseCodeschemaName(s, &target, nil) + assert.NoError(t, err) + + assert.Equal(t, "[Test Proc]", result.Value) + }) + + t.Run("errors on missing dot", func(t *testing.T) { + s := NewScanner("test.sql", "[code] TestProc") + s.NextToken() + var target []Unparsed + + result, err := ParseCodeschemaName(s, &target, nil) + assert.Error(t, err) + + assert.Equal(t, "", result.Value) + assert.ErrorContains(t, err, "must be followed by '.'") + }) + + t.Run("errors on missing identifier", func(t *testing.T) { + s := NewScanner("test.sql", "[code].123") + s.NextToken() + var target []Unparsed + + result, err := ParseCodeschemaName(s, &target, nil) + + assert.Error(t, err) + assert.Equal(t, "", result.Value) + assert.ErrorContains(t, err, "must be followed an identifier") + }) +} + +func TestDocument_recoverToNextStatement(t *testing.T) { + t.Run("recovers to declare", func(t *testing.T) { + s := NewScanner("test.sql", "invalid tokens here declare @x int = 1") + s.NextToken() + + RecoverToNextStatement(s, []string{"declare"}) + + fmt.Printf("%#v\n", s) + + assert.Equal(t, ReservedWordToken, s.TokenType()) + assert.Equal(t, "declare", s.ReservedWord()) + }) + + t.Run("recovers to create", func(t *testing.T) { + s := NewScanner("test.sql", "bad stuff create procedure") + s.NextToken() + + RecoverToNextStatement(s, []string{"create"}) + + assert.Equal(t, ReservedWordToken, s.TokenType()) + assert.Equal(t, "create", s.ReservedWord()) + }) + + t.Run("stops at EOF", func(t *testing.T) { + s := 
NewScanner("test.sql", "no keywords") + s.NextToken() + + RecoverToNextStatement(s, []string{}) + + assert.Equal(t, EOFToken, s.TokenType()) + }) +} + +func TestDocument_recoverToNextStatementCopying(t *testing.T) { + t.Run("copies tokens while recovering", func(t *testing.T) { + s := NewScanner("test.sql", "bad token declare") + s.NextToken() + var target []Unparsed + + RecoverToNextStatementCopying(s, &target, []string{"declare"}) + + assert.NotEmpty(t, target) + assert.Equal(t, "declare", s.ReservedWord()) + }) +} diff --git a/sqlparser/node_test.go b/sqlparser/node_test.go new file mode 100644 index 0000000..04173c6 --- /dev/null +++ b/sqlparser/node_test.go @@ -0,0 +1 @@ +package sqlparser diff --git a/sqlparser/pgsql_document.go b/sqlparser/pgsql_document.go index f6e2e8f..e97ec32 100644 --- a/sqlparser/pgsql_document.go +++ b/sqlparser/pgsql_document.go @@ -86,9 +86,9 @@ func (d PGSqlDocument) WithoutPos() Document { // // PostgreSQL uses schema.object notation rather than [schema].[object]. func (doc *PGSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { - nodes := &Nodes{ - TokenHandlers: map[string]func(*Scanner, *Nodes) bool{ - "create": func(s *Scanner, n *Nodes) bool { + batch := &Batch{ + TokenHandlers: map[string]func(*Scanner, *Batch) bool{ + "create": func(s *Scanner, n *Batch) bool { // Parse CREATE FUNCTION, CREATE PROCEDURE, CREATE TYPE, etc. c := doc.parseCreate(s, n.CreateStatements) c.Driver = &stdlib.Driver{} @@ -103,9 +103,9 @@ func (doc *PGSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { }, } - hasMore = nodes.Parse(s) - if nodes.HasErrors() { - doc.errors = append(doc.errors, nodes.Errors...) + hasMore = batch.Parse(s) + if batch.HasErrors() { + doc.errors = append(doc.errors, batch.Errors...) 
} return hasMore diff --git a/sqlparser/scanner.go b/sqlparser/scanner.go index 9e6c27f..18b2b77 100644 --- a/sqlparser/scanner.go +++ b/sqlparser/scanner.go @@ -1,6 +1,7 @@ package sqlparser import ( + "fmt" "regexp" "strings" "unicode" @@ -221,7 +222,9 @@ func (s *Scanner) nextToken() TokenType { return VariableIdentifierToken } else { rw := strings.ToLower(s.Token()) - if _, ok := reservedWords[rw]; ok { + _, ok := reservedWords[rw] + fmt.Printf("%#v %t\n", rw, ok) + if ok { s.reservedWord = rw return ReservedWordToken } else { @@ -243,7 +246,10 @@ func (s *Scanner) nextToken() TokenType { // no, it is instead an identifier starting with N... s.scanIdentifier() rw := strings.ToLower(s.Token()) - if _, ok := reservedWords[rw]; ok { + _, ok := reservedWords[rw] + fmt.Printf("%#v %t\n", rw, ok) + + if ok { s.reservedWord = rw return ReservedWordToken } else { @@ -380,6 +386,7 @@ func (s *Scanner) scanWhitespace() TokenType { return WhitespaceToken } +// tsql (mssql) reservered words var reservedWords = map[string]struct{}{ "add": struct{}{}, "external": struct{}{}, @@ -565,8 +572,6 @@ var reservedWords = map[string]struct{}{ "writetext": struct{}{}, "exit": struct{}{}, "proc": struct{}{}, - // pgsql - "replace": struct{}{}, } // apparently 'within group' is also reserved but dropping that.. 
diff --git a/sqlparser/tokentype.go b/sqlparser/tokentype.go
index 835e605..644d8da 100644
--- a/sqlparser/tokentype.go
+++ b/sqlparser/tokentype.go
@@ -38,6 +38,10 @@ const (
 	UnexpectedCharacterToken
 	NonUTF8ErrorToken
 
+	// PGSQL specific
+	DollarQuotedStringStartToken
+	DollarQuotedStringEndToken
+
 	BatchSeparatorToken
 	MalformedBatchSeparatorToken
 	EOFToken
@@ -90,6 +94,9 @@ var tokenToDescription = map[TokenType]string{
 	UnexpectedCharacterToken:     "UnexpectedCharacterToken",
 	NonUTF8ErrorToken:            "NonUTF8ErrorToken",
 
+	DollarQuotedStringStartToken: "DollarQuotedStringStartToken",
+	DollarQuotedStringEndToken:   "DollarQuotedStringEndToken",
+
 	// After a lot of back and forth we added the batch separater to the scanner.
 	// We implement sqlcmd's use of the go
 	// do separate batches. sqlcmd will only support GO at the start of
diff --git a/sqlparser/tsql_document.go b/sqlparser/tsql_document.go
index 8c17499..81f85bb 100644
--- a/sqlparser/tsql_document.go
+++ b/sqlparser/tsql_document.go
@@ -221,6 +221,7 @@ func (doc *TSqlDocument) parseBatchSeparator(s *Scanner) {
 	// just saw a 'go'; just make sure there's nothing bad trailing it
 	// (if there is, convert to errors and move on until the line is consumed
 	errorEmitted := false
+	// continuously process tokens until a non-whitespace, non-malformed token is encountered.
for { switch s.NextToken() { case WhitespaceToken: @@ -264,9 +265,9 @@ func (doc *TSqlDocument) parseDeclareBatch(s *Scanner) (hasMore bool) { } func (doc *TSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { - nodes := &Nodes{ - TokenHandlers: map[string]func(*Scanner, *Nodes) bool{ - "declare": func(s *Scanner, n *Nodes) bool { + batch := &Batch{ + TokenHandlers: map[string]func(*Scanner, *Batch) bool{ + "declare": func(s *Scanner, n *Batch) bool { // First declare-statement; enter a mode where we assume all contents // of batch are declare statements if !isFirst { @@ -276,7 +277,7 @@ func (doc *TSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { // regardless of errors, go on and parse as far as we get... return doc.parseDeclareBatch(s) }, - "create": func(s *Scanner, n *Nodes) bool { + "create": func(s *Scanner, n *Batch) bool { // should be start of create procedure or create function... c := doc.parseCreate(s, n.CreateStatements) c.Driver = &mssql.Driver{} @@ -290,9 +291,9 @@ func (doc *TSqlDocument) parseBatch(s *Scanner, isFirst bool) (hasMore bool) { }, }, } - hasMore = nodes.Parse(s) - if nodes.HasErrors() { - doc.errors = append(doc.errors, nodes.Errors...) + hasMore = batch.Parse(s) + if batch.HasErrors() { + doc.errors = append(doc.errors, batch.Errors...) 
} return hasMore diff --git a/sqlparser/tsql_document_test.go b/sqlparser/tsql_document_test.go index b44d575..3d07ad3 100644 --- a/sqlparser/tsql_document_test.go +++ b/sqlparser/tsql_document_test.go @@ -1,6 +1,7 @@ package sqlparser import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -228,66 +229,16 @@ func TestDocument_parseBatchSeparator(t *testing.T) { t.Run("errors on malformed separator", func(t *testing.T) { doc := &TSqlDocument{} s := NewScanner("test.sql", "go -- comment") - s.NextToken() - + tt := s.NextToken() + fmt.Printf("%#v %#v\n", s, tt) doc.parseBatchSeparator(s) + fmt.Printf("%#v\n", s) assert.NotEmpty(t, doc.errors) assert.Contains(t, doc.errors[0].Message, "should be alone") }) } -func TestDocument_parseCodeschemaName(t *testing.T) { - t.Run("parses unquoted identifier", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "[code].TestProc") - s.NextToken() - var target []Unparsed - - result := doc.parseCodeschemaName(s, &target) - - assert.Equal(t, "[TestProc]", result.Value) - assert.NotEmpty(t, target) - }) - - t.Run("parses quoted identifier", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "[code].[Test Proc]") - s.NextToken() - var target []Unparsed - - result := doc.parseCodeschemaName(s, &target) - - assert.Equal(t, "[Test Proc]", result.Value) - }) - - t.Run("errors on missing dot", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "[code] TestProc") - s.NextToken() - var target []Unparsed - - result := doc.parseCodeschemaName(s, &target) - - assert.Equal(t, "", result.Value) - assert.NotEmpty(t, doc.errors) - assert.Contains(t, doc.errors[0].Message, "must be followed by '.'") - }) - - t.Run("errors on missing identifier", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "[code].123") - s.NextToken() - var target []Unparsed - - result := doc.parseCodeschemaName(s, &target) - - assert.Equal(t, "", result.Value) - 
assert.NotEmpty(t, doc.errors) - assert.Contains(t, doc.errors[0].Message, "must be followed an identifier") - }) -} - func TestDocument_parseCreate(t *testing.T) { t.Run("parses simple procedure", func(t *testing.T) { doc := &TSqlDocument{} @@ -465,60 +416,3 @@ func TestCreateUnparsed(t *testing.T) { assert.Equal(t, Pos{File: "test.sql", Line: 1, Col: 1}, unparsed.Start) }) } - -func TestDocument_recoverToNextStatement(t *testing.T) { - t.Run("recovers to declare", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "invalid tokens here declare @x int = 1") - s.NextToken() - - doc.recoverToNextStatement(s) - - assert.Equal(t, ReservedWordToken, s.TokenType()) - assert.Equal(t, "declare", s.ReservedWord()) - }) - - t.Run("recovers to create", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "bad stuff create procedure") - s.NextToken() - - doc.recoverToNextStatement(s) - - assert.Equal(t, "create", s.ReservedWord()) - }) - - t.Run("recovers to go", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "error error go") - s.NextToken() - - doc.recoverToNextStatement(s) - - assert.Equal(t, "go", s.ReservedWord()) - }) - - t.Run("stops at EOF", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "no keywords") - s.NextToken() - - doc.recoverToNextStatement(s) - - assert.Equal(t, EOFToken, s.TokenType()) - }) -} - -func TestDocument_recoverToNextStatementCopying(t *testing.T) { - t.Run("copies tokens while recovering", func(t *testing.T) { - doc := &TSqlDocument{} - s := NewScanner("test.sql", "bad token declare") - s.NextToken() - var target []Unparsed - - doc.recoverToNextStatementCopying(s, &target) - - assert.NotEmpty(t, target) - assert.Equal(t, "declare", s.ReservedWord()) - }) -}