134 changes: 129 additions & 5 deletions flow/connectors/postgres/cdc.go
@@ -797,6 +797,38 @@
tableSchemaDelta.SrcTableName, tableSchemaDelta.AddedColumns))
records.AddSchemaDelta(req.TableNameMapping, tableSchemaDelta)
}
schemaMigrationEnabled, err := internal.PeerDBPostgresCDCMigrationEnabled(ctx, req.Env)
if err != nil {
return fmt.Errorf("error checking if schema migration is enabled: %w", err)
}
if schemaMigrationEnabled && (len(tableSchemaDelta.DroppedColumns) > 0 || len(tableSchemaDelta.AlteredColumns) > 0) {
logger.Info(fmt.Sprintf("Detected schema change for table %s, droppedColumns: %v, alteredColumns: %v",
tableSchemaDelta.SrcTableName, tableSchemaDelta.DroppedColumns, tableSchemaDelta.AlteredColumns))
// register the delta once, even when both dropped and altered columns are present
records.AddSchemaDelta(req.TableNameMapping, tableSchemaDelta)
}
schemaMigrationIndexEnabled, err := internal.PeerDBPostgresCDCMigrationIndexEnabled(ctx, req.Env)
if err != nil {
return fmt.Errorf("error checking if schema migration index is enabled: %w", err)
}
if schemaMigrationIndexEnabled && (len(tableSchemaDelta.AddedIndexes) > 0 || len(tableSchemaDelta.DroppedIndexes) > 0) {
logger.Info(fmt.Sprintf("Detected schema change for table %s, addedIndexes: %v, droppedIndexes: %v",
tableSchemaDelta.SrcTableName, tableSchemaDelta.AddedIndexes, tableSchemaDelta.DroppedIndexes))
// register the delta once, even when both added and dropped indexes are present
records.AddSchemaDelta(req.TableNameMapping, tableSchemaDelta)
}

case *model.MessageRecord[Items]:
// if cdc store empty, we can move lsn,
@@ -1141,6 +1173,10 @@
AddedColumns: nil,
System: prevSchema.System,
NullableEnabled: prevSchema.NullableEnabled,
DroppedColumns: nil,
AlteredColumns: nil,
AddedIndexes: nil,
DroppedIndexes: nil,
}
for _, column := range currRel.Columns {
// not present in previous relation message, but in current one, so added.
@@ -1171,15 +1207,27 @@
// present in previous and current relation messages, but the data type has changed,
// so we record it in AlteredColumns.
} else if prevRelMap[column.Name] != currRelMap[column.Name] {
p.logger.Warn(fmt.Sprintf("Detected column %s with type changed from %s to %s in table %s, but not propagating",
column.Name, prevRelMap[column.Name], currRelMap[column.Name], schemaDelta.SrcTableName))
p.logger.Info("Detected altered column",
slog.String("columnName", column.Name),
slog.String("oldType", prevRelMap[column.Name]),
slog.String("newType", currRelMap[column.Name]),
slog.String("relationName", schemaDelta.SrcTableName))

schemaDelta.AlteredColumns = append(schemaDelta.AlteredColumns, &protos.FieldDescription{
Name: column.Name,
Type: currRelMap[column.Name],
TypeModifier: column.TypeModifier,
})
}
}
for _, column := range prevSchema.Columns {
// present in previous relation message, but not in current one, so dropped.
if _, ok := currRelMap[column.Name]; !ok {
p.logger.Warn(fmt.Sprintf("Detected dropped column %s in table %s, but not propagating", column,
schemaDelta.SrcTableName))
p.logger.Info("Detected dropped column",
slog.String("columnName", column.Name),
slog.String("relationName", schemaDelta.SrcTableName))

schemaDelta.DroppedColumns = append(schemaDelta.DroppedColumns, column.Name)
}
}
if len(potentiallyNullableAddedColumns) > 0 {
@@ -1212,9 +1260,37 @@
}
}

srcSchemaTable, err := utils.ParseSchemaTable(schemaDelta.SrcTableName)
if err != nil {
return nil, fmt.Errorf("error parsing source table name %s for index diff: %w", schemaDelta.SrcTableName, err)
}

currIndexes, err := p.getIndexesForTable(ctx, currRel.RelationID, srcSchemaTable)
if err != nil {
return nil, fmt.Errorf("error getting indexes for relation %s: %w", schemaDelta.SrcTableName, err)
}

addedIdx, droppedIdx := diffIndexes(prevSchema.Indexes, currIndexes)
if len(addedIdx) > 0 || len(droppedIdx) > 0 {
schemaDelta.AddedIndexes = addedIdx
schemaDelta.DroppedIndexes = droppedIdx
p.logger.Info("Detected index changes",
slog.Any("added", addedIdx),
slog.Any("dropped", droppedIdx),
slog.String("relationName", schemaDelta.SrcTableName))
}

// Update in-memory schema so future diffs are incremental
prevSchema.Indexes = currIndexes

p.relationMessageMapping[currRel.RelationID] = currRel
// only emit a relation record when there is an actionable delta
hasDelta := len(schemaDelta.AddedColumns) > 0 ||
len(schemaDelta.DroppedColumns) > 0 ||
len(schemaDelta.AlteredColumns) > 0 ||
len(schemaDelta.AddedIndexes) > 0 ||
len(schemaDelta.DroppedIndexes) > 0
if hasDelta {
return &model.RelationRecord[Items]{
BaseRecord: p.baseRecord(lsn),
TableSchemaDelta: schemaDelta,
@@ -1281,3 +1357,51 @@

return relID, nil
}

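// indexesEqual reports whether two index descriptions match on name, method,
// uniqueness, predicate, and their (order-sensitive) key and INCLUDE column lists.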
func indexesEqual(a, b *protos.IndexDescription) bool {
if a == nil || b == nil {
return a == b
}
if a.Name != b.Name || a.Method != b.Method || a.IsUnique != b.IsUnique || a.Where != b.Where {
return false
}
if !slices.Equal(a.ColumnNames, b.ColumnNames) {
return false
}
if !slices.Equal(a.IncludeColumns, b.IncludeColumns) {
return false
}
return true
}

// diffIndexes compares two index lists by name and definition and returns the
// indexes that were added (or redefined) and the indexes that were dropped.
func diffIndexes(
prev, curr []*protos.IndexDescription,
) ([]*protos.IndexDescription, []*protos.IndexDescription) {
var added, dropped []*protos.IndexDescription
prevMap := make(map[string]*protos.IndexDescription, len(prev))
for _, idx := range prev {
if idx == nil {
continue
}
prevMap[idx.Name] = idx
}
currMap := make(map[string]*protos.IndexDescription, len(curr))
for _, idx := range curr {
if idx == nil {
continue
}
currMap[idx.Name] = idx
}

for name, cIdx := range currMap {
pIdx, ok := prevMap[name]
if !ok {
added = append(added, cIdx)
} else if !indexesEqual(pIdx, cIdx) {
// same name but a different definition: report the old index as dropped
// and the new one as added so consumers can drop and recreate it
dropped = append(dropped, pIdx)
added = append(added, cIdx)
}
}
for name := range prevMap {
if _, ok := currMap[name]; !ok {
dropped = append(dropped, prevMap[name])
}
}
return added, dropped
}
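As a quick illustration of the diffIndexes contract, here is a hypothetical test sketch (not part of this PR; it assumes the package's testing and generated protos imports): an index present only in curr is reported as added, one present only in prev as dropped, and a same-named index whose definition changed appears in both lists.

func TestDiffIndexes(t *testing.T) {
	prevIdx := &protos.IndexDescription{Name: "idx_a", Method: "btree", ColumnNames: []string{"a"}}
	currIdx := &protos.IndexDescription{Name: "idx_a", Method: "hash", ColumnNames: []string{"a"}}
	newIdx := &protos.IndexDescription{Name: "idx_b", Method: "btree", ColumnNames: []string{"b"}}

	added, dropped := diffIndexes(
		[]*protos.IndexDescription{prevIdx},
		[]*protos.IndexDescription{currIdx, newIdx},
	)
	// idx_a changed its access method and idx_b is new, so both are reported
	// as added; the old idx_a definition is reported as dropped.
	if len(added) != 2 || len(dropped) != 1 {
		t.Fatalf("unexpected diff: added=%v, dropped=%v", added, dropped)
	}
}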
83 changes: 83 additions & 0 deletions flow/connectors/postgres/client.go
@@ -207,6 +207,68 @@ func (c *PostgresConnector) getColumnNamesForIndex(ctx context.Context, indexOID
return cols, nil
}

// getIndexesForTable returns all non-primary indexes for the given table.
func (c *PostgresConnector) getIndexesForTable(
ctx context.Context,
relID uint32,
schemaTable *utils.SchemaTable,
) ([]*protos.IndexDescription, error) {
query := `
SELECT
ci.relname AS index_name,
am.amname AS method,
i.indisunique AS is_unique,
pg_get_expr(i.indpred, i.indrelid) AS predicate,
(
SELECT array_agg(a.attname ORDER BY u.nr)
FROM unnest(i.indkey) WITH ORDINALITY AS u(attnum, nr)
JOIN pg_attribute AS a ON a.attrelid = i.indrelid AND a.attnum = u.attnum
) as column_names
FROM pg_index i
JOIN pg_class ct ON ct.oid = i.indrelid
JOIN pg_namespace ns ON ns.oid = ct.relnamespace
JOIN pg_class ci ON ci.oid = i.indexrelid
JOIN pg_am am ON am.oid = ci.relam
WHERE i.indrelid = $1
AND ns.nspname = $2
AND ct.relname = $3
AND NOT i.indisprimary`

rows, err := c.conn.Query(ctx, query, relID, schemaTable.Schema, schemaTable.Table)
if err != nil {
return nil, fmt.Errorf("error querying indexes for table %s: %w", schemaTable, err)
}

var indexes []*protos.IndexDescription
var (
name string
method string
isUnique bool
whereClause pgtype.Text
columnNames []string
)
_, err = pgx.ForEachRow(rows, []any{&name, &method, &isUnique, &whereClause, &columnNames}, func() error {
where := ""
if whereClause.Valid {
where = whereClause.String
}
indexes = append(indexes, &protos.IndexDescription{
Name: name,
ColumnNames: columnNames,
Method: method,
IsUnique: isUnique,
Where: where,
IncludeColumns: nil, // INCLUDE columns are not queried yet; see the sketch after this function
})
return nil
})
if err != nil {
return nil, fmt.Errorf("error iterating index rows for table %s: %w", schemaTable, err)
}

return indexes, nil
}
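The IncludeColumns field above is left empty. As a hedged sketch (not part of this PR; the helper name getIncludeColumnsForIndex is hypothetical), the INCLUDE columns of a covering index could be fetched from pg_index, where indkey lists the indnkeyatts key columns first and the INCLUDE columns after them:

func (c *PostgresConnector) getIncludeColumnsForIndex(
	ctx context.Context, indexOID uint32,
) ([]string, error) {
	// In pg_index, indkey holds the key columns followed by the INCLUDE
	// columns, so entries past indnkeyatts are the INCLUDE columns.
	query := `
	SELECT a.attname
	FROM pg_index i
	CROSS JOIN LATERAL unnest(i.indkey) WITH ORDINALITY AS u(attnum, nr)
	JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = u.attnum
	WHERE i.indexrelid = $1 AND u.nr > i.indnkeyatts
	ORDER BY u.nr`
	rows, err := c.conn.Query(ctx, query, indexOID)
	if err != nil {
		return nil, fmt.Errorf("error querying INCLUDE columns for index %d: %w", indexOID, err)
	}
	return pgx.CollectRows(rows, pgx.RowTo[string])
}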

func (c *PostgresConnector) getNullableColumns(ctx context.Context, relID uint32) (map[string]struct{}, error) {
rows, err := c.conn.Query(ctx, "SELECT a.attname FROM pg_attribute a WHERE a.attrelid = $1 AND NOT a.attnotnull", relID)
if err != nil {
@@ -525,6 +587,27 @@ func generateCreateTableSQLForNormalizedTable(
return fmt.Sprintf(createNormalizedTableSQL, dstSchemaTable.String(), strings.Join(createTableSQLArray, ","))
}

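// generateCreateIndexesSQLForNormalizedTable renders a CREATE INDEX statement
// for every well-formed index description attached to the destination table schema.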
func generateCreateIndexesSQLForNormalizedTable(
dstSchemaTable *utils.SchemaTable,
tableSchema *protos.TableSchema,
) []string {
if tableSchema == nil || len(tableSchema.Indexes) == 0 {
return nil
}
stmts := make([]string, 0, len(tableSchema.Indexes))
for _, idx := range tableSchema.Indexes {
if idx == nil || len(idx.ColumnNames) == 0 || idx.Name == "" {
continue
}

stmt, err := buildCreateIndexStmt(dstSchemaTable.String(), idx)
if err != nil {
// skip index definitions we cannot render rather than emitting a broken statement
continue
}

stmts = append(stmts, stmt)
}

return stmts
}
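buildCreateIndexStmt is referenced above but not shown in this diff. A minimal sketch of what such a helper could look like, assuming it renders a CREATE INDEX statement from an IndexDescription and fails on incomplete definitions (the inline quoteIdent closure is illustrative, not the repository's own quoting utility; assumes errors, fmt, and strings imports):

func buildCreateIndexStmt(dstTable string, idx *protos.IndexDescription) (string, error) {
	if idx == nil || idx.Name == "" || len(idx.ColumnNames) == 0 {
		return "", errors.New("incomplete index description")
	}
	// double-quote identifiers, escaping embedded quotes
	quoteIdent := func(name string) string {
		return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
	}
	quoteList := func(names []string) string {
		quoted := make([]string, 0, len(names))
		for _, n := range names {
			quoted = append(quoted, quoteIdent(n))
		}
		return strings.Join(quoted, ", ")
	}
	var sb strings.Builder
	sb.WriteString("CREATE ")
	if idx.IsUnique {
		sb.WriteString("UNIQUE ")
	}
	fmt.Fprintf(&sb, "INDEX IF NOT EXISTS %s ON %s", quoteIdent(idx.Name), dstTable)
	if idx.Method != "" {
		fmt.Fprintf(&sb, " USING %s", idx.Method)
	}
	fmt.Fprintf(&sb, " (%s)", quoteList(idx.ColumnNames))
	if len(idx.IncludeColumns) > 0 {
		fmt.Fprintf(&sb, " INCLUDE (%s)", quoteList(idx.IncludeColumns))
	}
	if idx.Where != "" {
		fmt.Fprintf(&sb, " WHERE %s", idx.Where)
	}
	return sb.String(), nil
}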

func (c *PostgresConnector) GetLastSyncBatchID(ctx context.Context, jobName string) (int64, error) {
var result pgtype.Int8
if err := c.conn.QueryRow(ctx, fmt.Sprintf(