diff --git a/gotype.go b/gotype.go index 6c55fe6..7188591 100644 --- a/gotype.go +++ b/gotype.go @@ -92,7 +92,7 @@ func avroTypeOfUncached(names *Names, t reflect.Type) (*Type, error) { // TODO pass in wType so that we can determine a schema // even for partially specified Go types (e.g. interface{} values) // See https://github.com/heetch/avro/issues/34 - schemaVal, err := gts.schemaForGoType(t) + schemaVal, err := gts.schemaForGoType(t, false) if err != nil { return nil, err } @@ -129,9 +129,11 @@ type goTypeSchema struct { defs map[reflect.Type]goTypeDef } -func (gts *goTypeSchema) schemaForGoType(t reflect.Type) (interface{}, error) { +// `ignoreCache` parameter prevents reusing registered type for an Anonymous field in a struct +// This is helpful since the Anonymous fields must be merged to the current struct +func (gts *goTypeSchema) schemaForGoType(t reflect.Type, ignoreCache bool) (interface{}, error) { d, ok := gts.defs[t] - if ok { + if !ignoreCache && ok { // We've already defined a name for this type, so use it. return d.name, nil } @@ -167,7 +169,7 @@ func (gts *goTypeSchema) schemaForGoType(t reflect.Type) (interface{}, error) { if t.Elem() == byteType { return "bytes", nil } - items, err := gts.schemaForGoType(t.Elem()) + items, err := gts.schemaForGoType(t.Elem(), false) if err != nil { return nil, err } @@ -180,7 +182,7 @@ func (gts *goTypeSchema) schemaForGoType(t reflect.Type) (interface{}, error) { if t.Key().Kind() != reflect.String { return nil, fmt.Errorf("map must have string key") } - values, err := gts.schemaForGoType(t.Elem()) + values, err := gts.schemaForGoType(t.Elem(), false) if err != nil { return nil, err } @@ -198,25 +200,36 @@ func (gts *goTypeSchema) schemaForGoType(t reflect.Type) (interface{}, error) { case nullType: return "null", nil } + // Define the struct type before filling in the definition // so that we'll find the definition if there's a recursive type. // The map returned by the define method holds a reference // to the same object held in gts.defs, so changing it // below will update the final definition. - def, err := gts.define(t, map[string]interface{}{ - "type": "record", - }, "") - if err != nil { - return nil, err + def := map[string]interface{}{} + if !ignoreCache { + var err error + def, err = gts.define(t, map[string]interface{}{ + "type": "record", + }, "") + if err != nil { + return nil, err + } } + // Note: don't start with nil fields because gogen-avro // doesn't like the nil value. fields := []interface{}{} for i := 0; i < t.NumField(); i++ { f := t.Field(i) if f.Anonymous { - return nil, fmt.Errorf("anonymous fields not yet supported (in %s)", t) + if err := gts.schemaForAnonymousField(f, &fields); err != nil { + return nil, err + } + + continue } + // Technically in Go, every field is optional because // that's the way that the encoding/json package works, // so we'll make them all optional. @@ -226,10 +239,33 @@ func (gts *goTypeSchema) schemaForGoType(t reflect.Type) (interface{}, error) { if name == "" { continue } - ftype, err := gts.schemaForGoType(f.Type) + + ftype, err := gts.schemaForGoType(f.Type, false) if err != nil { return nil, err } + + // Check if the same property has already been added by an anonymous struct + exactSameProperty := false + for _, definedField := range fields { + castedDefinedField := definedField.(map[string]interface{}) + + if name == castedDefinedField["name"] { + // If it's also the same type, we can ignore this duplicate + if err := gts.ensureCompatibleTypes(name, ftype, castedDefinedField["type"]); err != nil { + return nil, err + } else { + exactSameProperty = true + } + + break + } + } + + if exactSameProperty { + continue + } + d, err := gts.defaultForType(f.Type) if err != nil { return nil, err @@ -254,7 +290,7 @@ func (gts *goTypeSchema) schemaForGoType(t reflect.Type) (interface{}, error) { if t.Elem().Kind() == reflect.Ptr { return nil, fmt.Errorf("can only cope with a single level of pointer indirection") } - elem, err := gts.schemaForGoType(t.Elem()) + elem, err := gts.schemaForGoType(t.Elem(), false) if err != nil { return nil, err } @@ -270,6 +306,68 @@ func (gts *goTypeSchema) schemaForGoType(t reflect.Type) (interface{}, error) { } } +func (gts *goTypeSchema) schemaForAnonymousField(field reflect.StructField, fields *[]interface{}) error { + // Analyze the Anonymous struct as for others (it will end in the switch case "Struct" in all cases) + anonymousDefinition, err := gts.schemaForGoType(field.Type, true) + if err != nil { + return err + } + + // Cast to process it in the loop + castedAnonymousDefinition := anonymousDefinition.(map[string]interface{}) + castedAnonymousDefinitionFields := castedAnonymousDefinition["fields"].([]interface{}) + + // Merge anonymous fields with the parent ones + for _, definitionField := range castedAnonymousDefinitionFields { + // Before merging we make sure nested anonymous structures do not define an already existing property with different type + // This could come from the parent structure, or a sibling anonymous structure + exactSameProperty := false + for _, parentField := range *fields { + castedDefinitionField := definitionField.(map[string]interface{}) + castedParentField := parentField.(map[string]interface{}) + + if castedDefinitionField["name"] == castedParentField["name"] { + if err := gts.ensureCompatibleTypes(castedDefinitionField["name"].(string), castedDefinitionField["type"], castedParentField["type"]); err != nil { + return err + } else { + exactSameProperty = true + } + + break + } + } + + if exactSameProperty { + continue + } + + *fields = append(*fields, definitionField) + } + + return err +} + +func (gts *goTypeSchema) ensureCompatibleTypes(propertyName string, type1 interface{}, type2 interface{}) error { + // We have to manage the 2 Avro type representations (string and record) + type1Name, ok := type1.(string) + if !ok { + recordType := type1.(map[string]interface{}) + type1Name = recordType["name"].(string) + } + + type2Name, ok := type2.(string) + if !ok { + recordType := type2.(map[string]interface{}) + type2Name = recordType["name"].(string) + } + + if type2Name != type1Name { + return fmt.Errorf("the field %q has already been added by an anonymous structure with a different type (current: %q, defined: %q)", propertyName, type2Name, type1Name) + } + + return nil +} + func (gts *goTypeSchema) define(t reflect.Type, def0 interface{}, defaultName string) (map[string]interface{}, error) { def, ok := def0.(map[string]interface{}) if !ok {