diff --git a/custody.go b/custody.go index 852d22f..9981a29 100644 --- a/custody.go +++ b/custody.go @@ -287,7 +287,7 @@ func (c *Custody) handleBlockChainEvent(l types.Log) { return nil }) if err != nil { - log.Printf("[Closed] Error closing channel in database: %v", err) + log.Printf("[Closed] Error closing channel: %v", err) return } c.sendBalanceUpdate(channel.Participant) @@ -296,33 +296,59 @@ func (c *Custody) handleBlockChainEvent(l types.Log) { case custodyAbi.Events["Resized"].ID: ev, err := c.custody.ParseResized(l) if err != nil { - log.Println("error parsing ChannelResized event:", err) + log.Println("error parsing Resized event:", err) return } log.Printf("Resized event data: %+v\n", ev) - channelID := common.BytesToHash(ev.ChannelId[:]) - var channel Channel - result := c.db.Where("channel_id = ?", channelID.Hex()).First(&channel) - if result.Error != nil { - log.Println("error finding channel:", result.Error) - return - } + err = c.db.Transaction(func(tx *gorm.DB) error { + channelID := common.BytesToHash(ev.ChannelId[:]).Hex() + result := c.db.Where("channel_id = ?", channelID).First(&channel) + if result.Error != nil { + return fmt.Errorf("error finding channel: %w", result.Error) + } - newAmount := int64(channel.Amount) - for _, change := range ev.DeltaAllocations { - newAmount += change.Int64() - } + newAmount := int64(channel.Amount) + for _, change := range ev.DeltaAllocations { + newAmount += change.Int64() + } + + channel.Amount = uint64(newAmount) + channel.UpdatedAt = time.Now() + channel.Version++ + if err := c.db.Save(&channel).Error; err != nil { + return fmt.Errorf("[Resized] Error saving channel in database: %w", err) + } - channel.Amount = uint64(newAmount) - channel.UpdatedAt = time.Now() - channel.Version++ - if err := c.db.Save(&channel).Error; err != nil { - log.Printf("[Resized] Error saving channel in database: %v", err) + resizeAmount := ev.DeltaAllocations[0] // Participant deposits or withdraws. + if resizeAmount.Cmp(big.NewInt(0)) != 0 { + asset, err := GetAssetByToken(tx, channel.Token, c.chainID) + if err != nil { + return fmt.Errorf("DB error fetching asset: %w", err) + } + + if asset == nil { + return fmt.Errorf("Asset not found in database for token: %s", channel.Token) + } + + amount := decimal.NewFromBigInt(resizeAmount, -int32(asset.Decimals)) + ledger := GetParticipantLedger(tx, channel.Participant) + if err := ledger.Record(channel.Participant, asset.Symbol, amount); err != nil { + log.Printf("[Resized] Error recording balance update for participant: %v", err) + return err + } + } + + return nil + }) + + if err != nil { + log.Printf("[Resized] Error resizing channel: %v", err) return } + c.sendBalanceUpdate(channel.Participant) c.sendChannelUpdate(channel) default: log.Println("Unknown event ID:", eventID.Hex())