From abcd252041fa30a81fbddfb00f1b195aae495bb1 Mon Sep 17 00:00:00 2001 From: Shivam Bhanushali Date: Sun, 16 Nov 2025 15:58:01 +0530 Subject: [PATCH 1/2] feat: Add schema migration support for alter and drop column --- flow/connectors/postgres/cdc.go | 43 +++++++++-- flow/connectors/postgres/postgres.go | 47 +++++++++++- .../postgres/postgres_schema_delta_test.go | 76 +++++++++++++++++++ flow/internal/dynamicconf.go | 12 +++ protos/flow.proto | 2 + 5 files changed, 174 insertions(+), 6 deletions(-) diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index f041814a43..b2caa8cbff 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -797,6 +797,22 @@ func PullCdcRecords[Items model.Items]( tableSchemaDelta.SrcTableName, tableSchemaDelta.AddedColumns)) records.AddSchemaDelta(req.TableNameMapping, tableSchemaDelta) } + schemaMigrationEnabled, err := internal.PeerDBPostgresCDCMigrationEnabled(ctx, req.Env) + if err != nil { + return fmt.Errorf("error checking if schema migration is enabled: %w", err) + } + if schemaMigrationEnabled { + if len(tableSchemaDelta.DroppedColumns) > 0 { + logger.Info(fmt.Sprintf("Detected schema change for table %s, droppedColumns: %v", + tableSchemaDelta.SrcTableName, tableSchemaDelta.DroppedColumns)) + records.AddSchemaDelta(req.TableNameMapping, tableSchemaDelta) + } + if len(tableSchemaDelta.AlteredColumns) > 0 { + logger.Info(fmt.Sprintf("Detected schema change for table %s, alteredColumns: %v", + tableSchemaDelta.SrcTableName, tableSchemaDelta.AlteredColumns)) + records.AddSchemaDelta(req.TableNameMapping, tableSchemaDelta) + } + } case *model.MessageRecord[Items]: // if cdc store empty, we can move lsn, @@ -1141,6 +1157,8 @@ func processRelationMessage[Items model.Items]( AddedColumns: nil, System: prevSchema.System, NullableEnabled: prevSchema.NullableEnabled, + DroppedColumns: nil, + AlteredColumns: nil, } for _, column := range currRel.Columns { // not present in previous relation message, but in current one, so added. @@ -1171,15 +1189,27 @@ func processRelationMessage[Items model.Items]( // present in previous and current relation messages, but data types have changed. // so we add it to AddedColumns and DroppedColumns, knowing that we process DroppedColumns first. } else if prevRelMap[column.Name] != currRelMap[column.Name] { - p.logger.Warn(fmt.Sprintf("Detected column %s with type changed from %s to %s in table %s, but not propagating", - column.Name, prevRelMap[column.Name], currRelMap[column.Name], schemaDelta.SrcTableName)) + p.logger.Info("Detected altered column", + slog.String("columnName", column.Name), + slog.String("oldType", prevRelMap[column.Name]), + slog.String("newType", currRelMap[column.Name]), + slog.String("relationName", schemaDelta.SrcTableName)) + + schemaDelta.AlteredColumns = append(schemaDelta.AlteredColumns, &protos.FieldDescription{ + Name: column.Name, + Type: currRelMap[column.Name], + TypeModifier: column.TypeModifier, + }) } } for _, column := range prevSchema.Columns { // present in previous relation message, but not in current one, so dropped. if _, ok := currRelMap[column.Name]; !ok { - p.logger.Warn(fmt.Sprintf("Detected dropped column %s in table %s, but not propagating", column, - schemaDelta.SrcTableName)) + p.logger.Info("Detected dropped column", + slog.String("columnName", column.Name), + slog.String("relationName", schemaDelta.SrcTableName)) + + schemaDelta.DroppedColumns = append(schemaDelta.DroppedColumns, column.Name) } } if len(potentiallyNullableAddedColumns) > 0 { @@ -1214,7 +1244,10 @@ func processRelationMessage[Items model.Items]( p.relationMessageMapping[currRel.RelationID] = currRel // only log audit if there is actionable delta - if len(schemaDelta.AddedColumns) > 0 { + hasDelta := len(schemaDelta.AddedColumns) > 0 || + len(schemaDelta.DroppedColumns) > 0 || + len(schemaDelta.AlteredColumns) > 0 + if hasDelta { return &model.RelationRecord[Items]{ BaseRecord: p.baseRecord(lsn), TableSchemaDelta: schemaDelta, diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index e71ef66baf..61ed03f465 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -1134,7 +1134,7 @@ func (c *PostgresConnector) ReplayTableSchemaDeltas( defer shared.RollbackTx(tableSchemaModifyTx, c.logger) for _, schemaDelta := range schemaDeltas { - if schemaDelta == nil || len(schemaDelta.AddedColumns) == 0 { + if schemaDelta == nil { continue } @@ -1164,6 +1164,51 @@ func (c *PostgresConnector) ReplayTableSchemaDeltas( slog.String("dstTableName", schemaDelta.DstTableName), ) } + + for _, droppedColumn := range schemaDelta.DroppedColumns { + dstSchemaTable, err := utils.ParseSchemaTable(schemaDelta.DstTableName) + if err != nil { + return fmt.Errorf("error parsing schema and table for %s: %w", schemaDelta.DstTableName, err) + } + _, err = c.execWithLoggingTx(ctx, fmt.Sprintf( + "ALTER TABLE %s.%s DROP COLUMN IF EXISTS %s", + utils.QuoteIdentifier(dstSchemaTable.Schema), + utils.QuoteIdentifier(dstSchemaTable.Table), + utils.QuoteIdentifier(droppedColumn)), tableSchemaModifyTx) + + c.logger.Info("[schema delta replay] dropped column", + slog.String("columnName", droppedColumn), + slog.String("dstTableName", schemaDelta.DstTableName), + ) + } + + for _, alteredColumn := range schemaDelta.AlteredColumns { + columnType := alteredColumn.Type + if schemaDelta.System == protos.TypeSystem_Q { + columnType = qValueKindToPostgresType(columnType) + } + + dstSchemaTable, err := utils.ParseSchemaTable(schemaDelta.DstTableName) + if err != nil { + return fmt.Errorf("error parsing schema and table for %s: %w", schemaDelta.DstTableName, err) + } + + quotedColumnName := utils.QuoteIdentifier(alteredColumn.Name) + stmt := fmt.Sprintf("ALTER TABLE %s.%s ALTER COLUMN %s TYPE %s USING %s::%s", + utils.QuoteIdentifier(dstSchemaTable.Schema), + utils.QuoteIdentifier(dstSchemaTable.Table), + quotedColumnName, columnType, quotedColumnName, columnType) + + _, err = c.execWithLoggingTx(ctx, stmt, tableSchemaModifyTx) + if err != nil { + return fmt.Errorf("failed to alter column %s for table %s: %w", alteredColumn.Name, schemaDelta.DstTableName, err) + } + c.logger.Info("[schema delta replay] altered column", + slog.String("columnName", alteredColumn.Name), + slog.String("newType", columnType), + slog.String("dstTableName", schemaDelta.DstTableName), + ) + } } if err := tableSchemaModifyTx.Commit(ctx); err != nil { diff --git a/flow/connectors/postgres/postgres_schema_delta_test.go b/flow/connectors/postgres/postgres_schema_delta_test.go index 4fc4333a1d..b2f51d18ad 100644 --- a/flow/connectors/postgres/postgres_schema_delta_test.go +++ b/flow/connectors/postgres/postgres_schema_delta_test.go @@ -92,6 +92,82 @@ func (s PostgresSchemaDeltaTestSuite) TestSimpleAddColumn() { }, output[tableName]) } +func (s PostgresSchemaDeltaTestSuite) TestSimpleAlterColumn() { + tableName := s.schema + ".simple_alter_column" + _, err := s.connector.conn.Exec(s.t.Context(), + fmt.Sprintf("CREATE TABLE %s(id INT PRIMARY KEY, hi INT)", tableName)) + require.NoError(s.t, err) + require.NoError(s.t, s.connector.ReplayTableSchemaDeltas(s.t.Context(), nil, "schema_delta_flow", nil, []*protos.TableSchemaDelta{{ + SrcTableName: tableName, + DstTableName: tableName, + AlteredColumns: []*protos.FieldDescription{ + { + Name: "hi", + Type: string(types.QValueKindBoolean), + TypeModifier: -1, + Nullable: true, + }, + }, + }})) + output, err := s.connector.GetTableSchema(s.t.Context(), nil, shared.InternalVersion_Latest, protos.TypeSystem_Q, + []*protos.TableMapping{{SourceTableIdentifier: tableName}}) + require.NoError(s.t, err) + require.Equal(s.t, &protos.TableSchema{ + TableIdentifier: tableName, + PrimaryKeyColumns: []string{"id"}, + System: protos.TypeSystem_Q, + Columns: []*protos.FieldDescription{ + { + Name: "id", + Type: string(types.QValueKindInt32), + TypeModifier: -1, + }, + { + Name: "hi", + Type: string(types.QValueKindBoolean), + TypeModifier: -1, + Nullable: true, + }, + }, + }, output[tableName]) + fmt.Println("output", output) +} + +func (s PostgresSchemaDeltaTestSuite) TestSimpleDropColumn() { + tableName := s.schema + ".simple_drop_column" + _, err := s.connector.conn.Exec(s.t.Context(), + fmt.Sprintf("CREATE TABLE %s(id INT PRIMARY KEY, hi INT, bye TEXT)", tableName)) + require.NoError(s.t, err) + + require.NoError(s.t, s.connector.ReplayTableSchemaDeltas(s.t.Context(), nil, "schema_delta_flow", nil, []*protos.TableSchemaDelta{{ + SrcTableName: tableName, + DstTableName: tableName, + DroppedColumns: []string{"hi"}, + }})) + + output, err := s.connector.GetTableSchema(s.t.Context(), nil, shared.InternalVersion_Latest, protos.TypeSystem_Q, + []*protos.TableMapping{{SourceTableIdentifier: tableName}}) + require.NoError(s.t, err) + require.Equal(s.t, &protos.TableSchema{ + TableIdentifier: tableName, + PrimaryKeyColumns: []string{"id"}, + System: protos.TypeSystem_Q, + Columns: []*protos.FieldDescription{ + { + Name: "id", + Type: string(types.QValueKindInt32), + TypeModifier: -1, + }, + { + Name: "bye", + Type: string(types.QValueKindString), + TypeModifier: -1, + Nullable: true, + }, + }, + }, output[tableName]) +} + func (s PostgresSchemaDeltaTestSuite) TestAddAllColumnTypes() { tableName := s.schema + ".add_drop_all_column_types" _, err := s.connector.conn.Exec(s.t.Context(), diff --git a/flow/internal/dynamicconf.go b/flow/internal/dynamicconf.go index 5659641ab4..dd321763ce 100644 --- a/flow/internal/dynamicconf.go +++ b/flow/internal/dynamicconf.go @@ -387,6 +387,14 @@ var DynamicSettings = [...]*protos.DynamicSetting{ ApplyMode: protos.DynconfApplyMode_APPLY_MODE_AFTER_RESUME, TargetForSetting: protos.DynconfTarget_ALL, }, + { + Name: "PEERDB_POSTGRES_CDC_SCHEMA_MIGRATION_ENABLED", + Description: "Enable/disable schema migration (alter and drop columns) for Postgres CDC", + DefaultValue: "false", + ValueType: protos.DynconfValueType_BOOL, + ApplyMode: protos.DynconfApplyMode_APPLY_MODE_IMMEDIATE, + TargetForSetting: protos.DynconfTarget_ALL, + }, } var DynamicIndex = func() map[string]int { @@ -715,3 +723,7 @@ func PeerDBPostgresWalSenderTimeout(ctx context.Context, env map[string]string) func PeerDBMetricsRecordAggregatesEnabled(ctx context.Context, env map[string]string) (bool, error) { return dynamicConfBool(ctx, env, "PEERDB_METRICS_RECORD_AGGREGATES_ENABLED") } + +func PeerDBPostgresCDCMigrationEnabled(ctx context.Context, env map[string]string) (bool, error) { + return dynamicConfBool(ctx, env, "PEERDB_POSTGRES_CDC_SCHEMA_MIGRATION_ENABLED") +} diff --git a/protos/flow.proto b/protos/flow.proto index ee94edfadb..6b85652003 100644 --- a/protos/flow.proto +++ b/protos/flow.proto @@ -431,6 +431,8 @@ message TableSchemaDelta { repeated FieldDescription added_columns = 3; TypeSystem system = 4; bool nullable_enabled = 5; + repeated string dropped_columns = 6; + repeated FieldDescription altered_columns = 7; } message QRepFlowState { From 3fdbaa02c8fdf054ab39071cefb738787580d1a1 Mon Sep 17 00:00:00 2001 From: Shivam Bhanushali Date: Sun, 16 Nov 2025 21:28:31 +0530 Subject: [PATCH 2/2] feat: Add schema index migration support for adding and dropping index --- flow/connectors/postgres/cdc.go | 93 +++++++++++++++- flow/connectors/postgres/client.go | 83 ++++++++++++++ flow/connectors/postgres/postgres.go | 104 ++++++++++++++++++ .../postgres/postgres_schema_delta_test.go | 88 +++++++++++++++ flow/internal/dynamicconf.go | 12 ++ protos/flow.proto | 18 +++ 6 files changed, 397 insertions(+), 1 deletion(-) diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index b2caa8cbff..0dd6bfef0e 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -813,6 +813,22 @@ func PullCdcRecords[Items model.Items]( records.AddSchemaDelta(req.TableNameMapping, tableSchemaDelta) } } + schemaMigrationIndexEnabled, err := internal.PeerDBPostgresCDCMigrationIndexEnabled(ctx, req.Env) + if err != nil { + return fmt.Errorf("error checking if schema migration index is enabled: %w", err) + } + if schemaMigrationIndexEnabled { + if len(tableSchemaDelta.AddedIndexes) > 0 { + logger.Info(fmt.Sprintf("Detected schema change for table %s, addedIndexes: %v", + tableSchemaDelta.SrcTableName, tableSchemaDelta.AddedIndexes)) + records.AddSchemaDelta(req.TableNameMapping, tableSchemaDelta) + } + if len(tableSchemaDelta.DroppedIndexes) > 0 { + logger.Info(fmt.Sprintf("Detected schema change for table %s, droppedIndexes: %v", + tableSchemaDelta.SrcTableName, tableSchemaDelta.DroppedIndexes)) + records.AddSchemaDelta(req.TableNameMapping, tableSchemaDelta) + } + } case *model.MessageRecord[Items]: // if cdc store empty, we can move lsn, @@ -1159,6 +1175,8 @@ func processRelationMessage[Items model.Items]( NullableEnabled: prevSchema.NullableEnabled, DroppedColumns: nil, AlteredColumns: nil, + AddedIndexes: nil, + DroppedIndexes: nil, } for _, column := range currRel.Columns { // not present in previous relation message, but in current one, so added. @@ -1242,11 +1260,36 @@ func processRelationMessage[Items model.Items]( } } + srcSchemaTable, err := utils.ParseSchemaTable(schemaDelta.SrcTableName) + if err != nil { + return nil, fmt.Errorf("error parsing source table name %s for index diff: %w", schemaDelta.SrcTableName, err) + } + + currIndexes, err := p.getIndexesForTable(ctx, currRel.RelationID, srcSchemaTable) + if err != nil { + return nil, fmt.Errorf("error getting indexes for relation %s: %w", schemaDelta.SrcTableName, err) + } + + addedIdx, droppedIdx := diffIndexes(prevSchema.Indexes, currIndexes) + if len(addedIdx) > 0 || len(droppedIdx) > 0 { + schemaDelta.AddedIndexes = addedIdx + schemaDelta.DroppedIndexes = droppedIdx + p.logger.Info("Detected index changes", + slog.Any("added", addedIdx), + slog.Any("dropped", droppedIdx), + slog.String("relationName", schemaDelta.SrcTableName)) + } + + // Update in-memory schema so future diffs are incremental + prevSchema.Indexes = currIndexes + p.relationMessageMapping[currRel.RelationID] = currRel // only log audit if there is actionable delta hasDelta := len(schemaDelta.AddedColumns) > 0 || len(schemaDelta.DroppedColumns) > 0 || - len(schemaDelta.AlteredColumns) > 0 + len(schemaDelta.AlteredColumns) > 0 || + len(schemaDelta.AddedIndexes) > 0 || + len(schemaDelta.DroppedIndexes) > 0 if hasDelta { return &model.RelationRecord[Items]{ BaseRecord: p.baseRecord(lsn), @@ -1314,3 +1357,51 @@ func (p *PostgresCDCSource) checkIfUnknownTableInherits(ctx context.Context, return relID, nil } + +func indexesEqual(a, b *protos.IndexDescription) bool { + if a == nil || b == nil { + return a == b + } + if a.Name != b.Name || a.Method != b.Method || a.IsUnique != b.IsUnique || a.Where != b.Where { + return false + } + if !slices.Equal(a.ColumnNames, b.ColumnNames) { + return false + } + if !slices.Equal(a.IncludeColumns, b.IncludeColumns) { + return false + } + return true +} + +func diffIndexes( + prev, curr []*protos.IndexDescription, +) (added []*protos.IndexDescription, dropped []*protos.IndexDescription) { + prevMap := make(map[string]*protos.IndexDescription, len(prev)) + for _, idx := range prev { + if idx == nil { + continue + } + prevMap[idx.Name] = idx + } + currMap := make(map[string]*protos.IndexDescription, len(curr)) + for _, idx := range curr { + if idx == nil { + continue + } + currMap[idx.Name] = idx + } + + for name, cIdx := range currMap { + pIdx, ok := prevMap[name] + if !ok || !indexesEqual(pIdx, cIdx) { + added = append(added, cIdx) + } + } + for name := range prevMap { + if _, ok := currMap[name]; !ok { + dropped = append(dropped, prevMap[name]) + } + } + return added, dropped +} diff --git a/flow/connectors/postgres/client.go b/flow/connectors/postgres/client.go index 05784993ac..f32b531af1 100644 --- a/flow/connectors/postgres/client.go +++ b/flow/connectors/postgres/client.go @@ -207,6 +207,68 @@ func (c *PostgresConnector) getColumnNamesForIndex(ctx context.Context, indexOID return cols, nil } +// getIndexesForTable returns all non-primary indexes for the given table. +func (c *PostgresConnector) getIndexesForTable( + ctx context.Context, + relID uint32, + schemaTable *utils.SchemaTable, +) ([]*protos.IndexDescription, error) { + query := ` + SELECT + ci.relname AS index_name, + am.amname AS method, + i.indisunique AS is_unique, + pg_get_expr(i.indpred, i.indrelid) AS predicate, + ( + SELECT array_agg(a.attname ORDER BY u.nr) + FROM unnest(i.indkey) WITH ORDINALITY AS u(attnum, nr) + JOIN pg_attribute AS a ON a.attrelid = i.indrelid AND a.attnum = u.attnum + ) as column_names + FROM pg_index i + JOIN pg_class ct ON ct.oid = i.indrelid + JOIN pg_namespace ns ON ns.oid = ct.relnamespace + JOIN pg_class ci ON ci.oid = i.indexrelid + JOIN pg_am am ON am.oid = ci.relam + WHERE i.indrelid = $1 + AND ns.nspname = $2 + AND ct.relname = $3 + AND NOT i.indisprimary` + + rows, err := c.conn.Query(ctx, query, relID, schemaTable.Schema, schemaTable.Table) + if err != nil { + return nil, fmt.Errorf("error querying indexes for table %s: %w", schemaTable, err) + } + + var indexes []*protos.IndexDescription + var ( + name string + method string + isUnique bool + whereClause pgtype.Text + columnNames []string + ) + _, err = pgx.ForEachRow(rows, []any{&name, &method, &isUnique, &whereClause, &columnNames}, func() error { + where := "" + if whereClause.Valid { + where = whereClause.String + } + indexes = append(indexes, &protos.IndexDescription{ + Name: name, + ColumnNames: columnNames, + Method: method, + IsUnique: isUnique, + Where: where, + IncludeColumns: nil, // can be filled later if you also query INCLUDE columns + }) + return nil + }) + if err != nil { + return nil, fmt.Errorf("error iterating index rows for table %s: %w", schemaTable, err) + } + + return indexes, nil +} + func (c *PostgresConnector) getNullableColumns(ctx context.Context, relID uint32) (map[string]struct{}, error) { rows, err := c.conn.Query(ctx, "SELECT a.attname FROM pg_attribute a WHERE a.attrelid = $1 AND NOT a.attnotnull", relID) if err != nil { @@ -525,6 +587,27 @@ func generateCreateTableSQLForNormalizedTable( return fmt.Sprintf(createNormalizedTableSQL, dstSchemaTable.String(), strings.Join(createTableSQLArray, ",")) } +func generateCreateIndexesSQLForNormalizedTable( + dstSchemaTable *utils.SchemaTable, + tableSchema *protos.TableSchema, +) []string { + if tableSchema == nil || len(tableSchema.Indexes) == 0 { + return nil + } + stmts := make([]string, 0, len(tableSchema.Indexes)) + for _, idx := range tableSchema.Indexes { + if idx == nil || len(idx.ColumnNames) == 0 || idx.Name == "" { + continue + } + + stmt, _ := buildCreateIndexStmt(dstSchemaTable.String(), idx) + + stmts = append(stmts, stmt) + } + + return stmts +} + func (c *PostgresConnector) GetLastSyncBatchID(ctx context.Context, jobName string) (int64, error) { var result pgtype.Int8 if err := c.conn.QueryRow(ctx, fmt.Sprintf( diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 61ed03f465..cd0f2a7e39 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -1002,6 +1002,12 @@ func (c *PostgresConnector) getTableSchemaForTable( selectedColumnsStr = strings.Join(selectedColumns, ",") } + var indexes []*protos.IndexDescription + indexes, err = c.getIndexesForTable(ctx, relID, schemaTable) + if err != nil { + return nil, fmt.Errorf("error getting indexes for table %s: %w", schemaTable, err) + } + // Get the column names and types rows, err := c.conn.Query(ctx, fmt.Sprintf(`SELECT %s FROM %s LIMIT 0`, selectedColumnsStr, schemaTable.String()), @@ -1053,6 +1059,7 @@ func (c *PostgresConnector) getTableSchemaForTable( Columns: columns, NullableEnabled: nullableEnabled, System: system, + Indexes: indexes, }, nil } @@ -1109,6 +1116,15 @@ func (c *PostgresConnector) SetupNormalizedTable( return false, fmt.Errorf("error while creating normalized table: %w", err) } + // create non-primary indexes on the normalized table + c.logger.Info("Creating non-primary indexes on the normalized table", slog.String("table", parsedNormalizedTable.String()), slog.Any("tableSchema", tableSchema)) + for _, stmt := range generateCreateIndexesSQLForNormalizedTable(parsedNormalizedTable, tableSchema) { + c.logger.Info("Creating index", slog.String("index", stmt)) + if _, err := c.execWithLoggingTx(ctx, stmt, createNormalizedTablesTx); err != nil { + return false, fmt.Errorf("error while creating index on normalized table %s: %w", parsedNormalizedTable.String(), err) + } + } + return false, nil } @@ -1209,6 +1225,46 @@ func (c *PostgresConnector) ReplayTableSchemaDeltas( slog.String("dstTableName", schemaDelta.DstTableName), ) } + + // indexes dropped + for _, idxName := range schemaDelta.DroppedIndexes { + dstSchemaTable, err := utils.ParseSchemaTable(schemaDelta.DstTableName) + if err != nil { + return fmt.Errorf("error parsing schema and table for %s: %w", schemaDelta.DstTableName, err) + } + stmt := fmt.Sprintf("DROP INDEX IF EXISTS %s.%s", + utils.QuoteIdentifier(dstSchemaTable.Schema), + utils.QuoteIdentifier(idxName.Name)) + if _, err := c.execWithLoggingTx(ctx, stmt, tableSchemaModifyTx); err != nil { + return fmt.Errorf("failed to drop index %s for table %s: %w", idxName, schemaDelta.DstTableName, err) + } + c.logger.Info("[schema delta replay] dropped index", + slog.String("indexName", idxName.Name), + slog.String("dstTableName", schemaDelta.DstTableName), + ) + } + + // indexes added + for _, idx := range schemaDelta.AddedIndexes { + if idx == nil || len(idx.ColumnNames) == 0 || idx.Name == "" { + continue + } + + dstSchemaTable, err := utils.ParseSchemaTable(schemaDelta.DstTableName) + if err != nil { + return fmt.Errorf("error parsing schema and table for %s: %w", schemaDelta.DstTableName, err) + } + + stmt, _ := buildCreateIndexStmt(dstSchemaTable.String(), idx) + + if _, err := c.execWithLoggingTx(ctx, stmt, tableSchemaModifyTx); err != nil { + return fmt.Errorf("failed to create index %s for table %s: %w", idx.Name, schemaDelta.DstTableName, err) + } + c.logger.Info("[schema delta replay] added index", + slog.String("indexName", idx.Name), + slog.String("dstTableName", schemaDelta.DstTableName), + ) + } } if err := tableSchemaModifyTx.Commit(ctx); err != nil { @@ -1217,6 +1273,54 @@ func (c *PostgresConnector) ReplayTableSchemaDeltas( return nil } +func buildCreateIndexStmt(tableIdent string, idx *protos.IndexDescription) (string, bool) { + if idx == nil || len(idx.ColumnNames) == 0 || idx.Name == "" { + return "", false + } + + var unique string + if idx.IsUnique { + unique = "UNIQUE " + } + + colIdents := make([]string, 0, len(idx.ColumnNames)) + for _, cName := range idx.ColumnNames { + colIdents = append(colIdents, utils.QuoteIdentifier(cName)) + } + + includeClause := "" + if len(idx.IncludeColumns) > 0 { + includeCols := make([]string, 0, len(idx.IncludeColumns)) + for _, cName := range idx.IncludeColumns { + includeCols = append(includeCols, utils.QuoteIdentifier(cName)) + } + includeClause = " INCLUDE (" + strings.Join(includeCols, ",") + ")" + } + + methodClause := "" + if idx.Method != "" && strings.ToLower(idx.Method) != "btree" { + methodClause = " USING " + idx.Method + } + + whereClause := "" + if idx.Where != "" { + whereClause = " WHERE " + idx.Where + } + + stmt := fmt.Sprintf( + "CREATE %sINDEX IF NOT EXISTS %s ON %s%s (%s)%s%s", + unique, + utils.QuoteIdentifier(idx.Name), + tableIdent, + methodClause, + strings.Join(colIdents, ","), + includeClause, + whereClause, + ) + + return stmt, true +} + // EnsurePullability ensures that a table is pullable, implementing the Connector interface. func (c *PostgresConnector) EnsurePullability( ctx context.Context, diff --git a/flow/connectors/postgres/postgres_schema_delta_test.go b/flow/connectors/postgres/postgres_schema_delta_test.go index b2f51d18ad..4b19acd13a 100644 --- a/flow/connectors/postgres/postgres_schema_delta_test.go +++ b/flow/connectors/postgres/postgres_schema_delta_test.go @@ -261,6 +261,94 @@ func (s PostgresSchemaDeltaTestSuite) TestAddDropWhitespaceColumnNames() { require.Equal(s.t, expectedTableSchema, output[tableName]) } +func (s PostgresSchemaDeltaTestSuite) TestSimpleDropIndex() { + tableName := s.schema + ".simple_drop_index" + _, err := s.connector.conn.Exec(s.t.Context(), + fmt.Sprintf("CREATE TABLE %s(id INT PRIMARY KEY, hi INT); CREATE INDEX idx_hi ON %s(hi);", tableName, tableName)) + require.NoError(s.t, err) + expectedTableSchema := &protos.TableSchema{ + TableIdentifier: tableName, + PrimaryKeyColumns: []string{"id"}, + Columns: []*protos.FieldDescription{ + { + Name: "id", + Type: string(types.QValueKindInt32), + TypeModifier: -1, + }, + { + Name: "hi", + Type: string(types.QValueKindInt32), + TypeModifier: -1, + Nullable: true, + }, + }, + System: protos.TypeSystem_Q, + } + + require.NoError(s.t, s.connector.ReplayTableSchemaDeltas(s.t.Context(), nil, "schema_delta_flow", nil, []*protos.TableSchemaDelta{{ + SrcTableName: tableName, + DstTableName: tableName, + DroppedIndexes: []*protos.IndexDescription{{Name: "idx_hi"}}, + }})) + output, err := s.connector.GetTableSchema(s.t.Context(), nil, shared.InternalVersion_Latest, protos.TypeSystem_Q, + []*protos.TableMapping{{SourceTableIdentifier: tableName}}) + require.NoError(s.t, err) + require.Equal(s.t, expectedTableSchema, output[tableName]) +} + +func (s PostgresSchemaDeltaTestSuite) TestSimpleAddIndex() { + tableName := s.schema + ".simple_add_index" + _, err := s.connector.conn.Exec(s.t.Context(), + fmt.Sprintf("CREATE TABLE %s(id INT PRIMARY KEY, hi INT)", tableName)) + require.NoError(s.t, err) + + expectedTableSchema := &protos.TableSchema{ + TableIdentifier: tableName, + PrimaryKeyColumns: []string{"id"}, + Columns: []*protos.FieldDescription{ + { + Name: "id", + Type: string(types.QValueKindInt32), + TypeModifier: -1, + }, + { + Name: "hi", + Type: string(types.QValueKindInt32), + TypeModifier: -1, + Nullable: true, + }, + }, + System: protos.TypeSystem_Q, + Indexes: []*protos.IndexDescription{ + { + Name: "idx_hi", + ColumnNames: []string{"hi"}, + Method: "btree", + IsUnique: false, + Where: "", + }, + }, + } + addedIndexes := make([]*protos.IndexDescription, 0) + for _, index := range expectedTableSchema.Indexes { + if index.Name != "id" { + addedIndexes = append(addedIndexes, index) + } + } + + require.NoError(s.t, s.connector.ReplayTableSchemaDeltas(s.t.Context(), nil, "schema_delta_flow", nil, []*protos.TableSchemaDelta{{ + SrcTableName: tableName, + DstTableName: tableName, + AddedIndexes: addedIndexes, + }})) + + output, err := s.connector.GetTableSchema(s.t.Context(), nil, shared.InternalVersion_Latest, protos.TypeSystem_Q, + []*protos.TableMapping{{SourceTableIdentifier: tableName}}) + + require.NoError(s.t, err) + require.Equal(s.t, expectedTableSchema, output[tableName]) +} + func TestPostgresSchemaDeltaTestSuite(t *testing.T) { e2eshared.RunSuite(t, SetupSuite) } diff --git a/flow/internal/dynamicconf.go b/flow/internal/dynamicconf.go index dd321763ce..c6c7e99d9e 100644 --- a/flow/internal/dynamicconf.go +++ b/flow/internal/dynamicconf.go @@ -395,6 +395,14 @@ var DynamicSettings = [...]*protos.DynamicSetting{ ApplyMode: protos.DynconfApplyMode_APPLY_MODE_IMMEDIATE, TargetForSetting: protos.DynconfTarget_ALL, }, + { + Name: "PEERDB_POSTGRES_CDC_SCHEMA_MIGRATION_INDEX_ENABLED", + Description: "Enable/disable schema migration (add and drop indexes) for Postgres CDC", + DefaultValue: "false", + ValueType: protos.DynconfValueType_BOOL, + ApplyMode: protos.DynconfApplyMode_APPLY_MODE_IMMEDIATE, + TargetForSetting: protos.DynconfTarget_ALL, + }, } var DynamicIndex = func() map[string]int { @@ -727,3 +735,7 @@ func PeerDBMetricsRecordAggregatesEnabled(ctx context.Context, env map[string]st func PeerDBPostgresCDCMigrationEnabled(ctx context.Context, env map[string]string) (bool, error) { return dynamicConfBool(ctx, env, "PEERDB_POSTGRES_CDC_SCHEMA_MIGRATION_ENABLED") } + +func PeerDBPostgresCDCMigrationIndexEnabled(ctx context.Context, env map[string]string) (bool, error) { + return dynamicConfBool(ctx, env, "PEERDB_POSTGRES_CDC_SCHEMA_MIGRATION_INDEX_ENABLED") +} diff --git a/protos/flow.proto b/protos/flow.proto index 6b85652003..d947b7564f 100644 --- a/protos/flow.proto +++ b/protos/flow.proto @@ -233,6 +233,7 @@ message TableSchema { TypeSystem system = 4; bool nullable_enabled = 5; repeated FieldDescription columns = 6; + repeated IndexDescription indexes = 7; } message FieldDescription { @@ -242,6 +243,21 @@ message FieldDescription { bool nullable = 4; } +message IndexDescription { + // index name on source + string name = 1; + // key columns in order + repeated string column_names = 2; + // index method, e.g. "btree", "gin", "gist", "hash" + string method = 3; + // whether index is UNIQUE + bool is_unique = 4; + // WHERE predicate for partial indexes, empty if not partial + string where = 5; + // INCLUDE columns (Postgres INCLUDE clause) + repeated string include_columns = 6; +} + message SetupTableSchemaBatchInput { reserved 2; map env = 1; @@ -433,6 +449,8 @@ message TableSchemaDelta { bool nullable_enabled = 5; repeated string dropped_columns = 6; repeated FieldDescription altered_columns = 7; + repeated IndexDescription added_indexes = 8; + repeated IndexDescription dropped_indexes = 9; } message QRepFlowState {