diff --git a/packages/http-client-csharp/emitter/src/lib/client-model-builder.ts b/packages/http-client-csharp/emitter/src/lib/client-model-builder.ts index 94f1bfd767d..7fee2d213f6 100644 --- a/packages/http-client-csharp/emitter/src/lib/client-model-builder.ts +++ b/packages/http-client-csharp/emitter/src/lib/client-model-builder.ts @@ -159,4 +159,7 @@ function navigateModels(sdkContext: CSharpEmitterContext) { for (const e of sdkContext.sdkPackage.enums) { fromSdkType(sdkContext, e); } + for (const u of sdkContext.sdkPackage.unions) { + fromSdkType(sdkContext, u); + } } diff --git a/packages/http-client-csharp/emitter/src/lib/type-converter.ts b/packages/http-client-csharp/emitter/src/lib/type-converter.ts index e18c7b33342..160b5acb209 100644 --- a/packages/http-client-csharp/emitter/src/lib/type-converter.ts +++ b/packages/http-client-csharp/emitter/src/lib/type-converter.ts @@ -384,12 +384,46 @@ function fromSdkBuiltInType( }; } -function fromUnionType(sdkContext: CSharpEmitterContext, union: SdkUnionType): InputUnionType { +function fromUnionType( + sdkContext: CSharpEmitterContext, + union: SdkUnionType, +): InputUnionType | InputModelType { const variantTypes: InputType[] = []; for (const value of union.variantTypes) { const variantType = fromSdkType(sdkContext, value); variantTypes.push(variantType); } + if (isDiscriminatedUnion(union)) { + const baseType: InputModelType = { + kind: "model", + name: union.name, + namespace: union.namespace, + crossLanguageDefinitionId: union.crossLanguageDefinitionId, + access: union.access, + usage: union.usage, + properties: [], + serializationOptions: {}, + summary: union.summary, + doc: union.doc, + deprecation: union.deprecation, + decorators: union.decorators, + external: fromSdkExternalTypeInfo(union), + } as InputModelType; + const discriminatedSubtypes: Record = {}; + variantTypes.forEach((variant) => { + if (variant.kind === "model") { + variant.baseModel = baseType; + if (variant.discriminatorValue !== undefined) { + discriminatedSubtypes[variant.discriminatorValue] = variant; + } + } + }); + if (Object.keys(discriminatedSubtypes).length > 0) { + baseType.discriminatedSubtypes = discriminatedSubtypes; + } + //TODO we should hoist the discriminator property to the base type + return baseType; + } return { kind: "union", @@ -401,6 +435,16 @@ function fromUnionType(sdkContext: CSharpEmitterContext, union: SdkUnionType): I }; } +function isDiscriminatedUnion(sdkType: SdkUnionType): boolean { + if (!sdkType.discriminatedOptions) { + return false; + } + + return sdkType.variantTypes.every((variant) => { + return variant.kind === "model" && !variant.baseModel; + }); +} + function fromSdkConstantType( sdkContext: CSharpEmitterContext, constantType: SdkConstantType, diff --git a/packages/http-client-csharp/emitter/test/Unit/type-converter.test.ts b/packages/http-client-csharp/emitter/test/Unit/type-converter.test.ts index 081708c903f..443e13ce70a 100644 --- a/packages/http-client-csharp/emitter/test/Unit/type-converter.test.ts +++ b/packages/http-client-csharp/emitter/test/Unit/type-converter.test.ts @@ -180,3 +180,52 @@ describe("External types", () => { strictEqual((jsonElementProp.type as any).external.minVersion, "8.0.0"); }); }); + +describe("Union types to model hierarchies", () => { + let runner: TestHost; + + beforeEach(async () => { + runner = await createEmitterTestHost(); + }); + it("should convert union with members to model hierarchy", async () => { + const program = await typeSpecCompile( + ` + model Alpha { + alphaProp: string; + type: "alpha"; + } + model Beta { + betaProp: int32; + type: "beta"; + } + @discriminated(#{ discriminatorPropertyName: "type", envelope: "none" }) + union MyUnion { + "alpha": Alpha, + "beta": Beta + } + op test(@body input: MyUnion): void; + `, + runner, + { IsTCGCNeeded: true }, + ); + const context = createEmitterContext(program); + const sdkContext = await createCSharpSdkContext(context); + const root = createModel(sdkContext); + + const alphaModel = root.models.find((m) => m.name === "Alpha"); + ok(alphaModel, "Alpha should exist"); + + const betaModel = root.models.find((m) => m.name === "Beta"); + ok(betaModel, "Beta should exist"); + + const myUnion = root.models.find((m) => m.name === "MyUnion"); + ok(myUnion, "MyUnion should exist"); + + // Validate that MyUnion is a model + strictEqual(myUnion.kind, "model", "MyUnion should be converted to a model"); + + // Validate that Alpha and Beta inherit from MyUnion + strictEqual(alphaModel.baseModel, myUnion, "Alpha should inherit from MyUnion"); + strictEqual(betaModel.baseModel, myUnion, "Beta should inherit from MyUnion"); + }); +});