Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,23 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
@Nullable
private final ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator;

// This is initialized once for the lifetime of the application. This enables re-using
// channels to S2A.
private static volatile ChannelCredentials s2aChannelCredentials;

/**
* Resets the s2aChannelCredentials of the {@link InstantiatingGrpcChannelProvider} class for
* testing purposes.
*
* <p>This should only be called from tests.
*/
@VisibleForTesting
static void resetS2AChannelCredentials() {
synchronized (InstantiatingGrpcChannelProvider.class) {
s2aChannelCredentials = null;
}
}

/*
* Experimental feature
*
Expand Down Expand Up @@ -595,43 +612,60 @@ ChannelCredentials createPlaintextToS2AChannelCredentials(String plaintextAddres
* @return {@link ChannelCredentials} configured to use S2A to create mTLS connection.
*/
ChannelCredentials createS2ASecuredChannelCredentials() {
SecureSessionAgentConfig config = s2aConfigProvider.getConfig();
String plaintextAddress = config.getPlaintextAddress();
String mtlsAddress = config.getMtlsAddress();
if (Strings.isNullOrEmpty(mtlsAddress)) {
// Fallback to plaintext connection to S2A.
LOG.log(
Level.INFO,
"Cannot establish an mTLS connection to S2A because autoconfig endpoint did not return a mtls address to reach S2A.");
return createPlaintextToS2AChannelCredentials(plaintextAddress);
}
// Currently, MTLS to MDS is only available on GCE. See:
// https://cloud.google.com/compute/docs/metadata/overview#https-mds
// Try to load MTLS-MDS creds.
File rootFile = new File(MTLS_MDS_ROOT_PATH);
File certKeyFile = new File(MTLS_MDS_CERT_CHAIN_AND_KEY_PATH);
if (rootFile.isFile() && certKeyFile.isFile()) {
// Try to connect to S2A using mTLS.
ChannelCredentials mtlsToS2AChannelCredentials = null;
try {
mtlsToS2AChannelCredentials =
createMtlsToS2AChannelCredentials(rootFile, certKeyFile, certKeyFile);
} catch (IOException ignore) {
// Fallback to plaintext-to-S2A connection on error.
LOG.log(
Level.WARNING,
"Cannot establish an mTLS connection to S2A due to error creating MTLS to MDS TlsChannelCredentials credentials, falling back to plaintext connection to S2A: "
+ ignore.getMessage());
return createPlaintextToS2AChannelCredentials(plaintextAddress);
if (s2aChannelCredentials == null) {
// s2aChannelCredentials is initialized once and shared by all instances of the class.
// To prevent a race on initialization, the object initialization is synchronized on the class
// object.
synchronized (InstantiatingGrpcChannelProvider.class) {
if (s2aChannelCredentials != null) {
return s2aChannelCredentials;
}
SecureSessionAgentConfig config = s2aConfigProvider.getConfig();
String plaintextAddress = config.getPlaintextAddress();
String mtlsAddress = config.getMtlsAddress();
if (Strings.isNullOrEmpty(mtlsAddress)) {
// Fallback to plaintext connection to S2A.
LOG.log(
Level.INFO,
"Cannot establish an mTLS connection to S2A because autoconfig endpoint did not return a mtls address to reach S2A.");
s2aChannelCredentials = createPlaintextToS2AChannelCredentials(plaintextAddress);
return s2aChannelCredentials;
}
// Currently, MTLS to MDS is only available on GCE. See:
// https://cloud.google.com/compute/docs/metadata/overview#https-mds
// Try to load MTLS-MDS creds.
File rootFile = new File(MTLS_MDS_ROOT_PATH);
File certKeyFile = new File(MTLS_MDS_CERT_CHAIN_AND_KEY_PATH);
if (rootFile.isFile() && certKeyFile.isFile()) {
// Try to connect to S2A using mTLS.
ChannelCredentials mtlsToS2AChannelCredentials = null;
try {
mtlsToS2AChannelCredentials =
createMtlsToS2AChannelCredentials(rootFile, certKeyFile, certKeyFile);
} catch (IOException ignore) {
// Fallback to plaintext-to-S2A connection on error.
LOG.log(
Level.WARNING,
"Cannot establish an mTLS connection to S2A due to error creating MTLS to MDS TlsChannelCredentials credentials, falling back to plaintext connection to S2A: "
+ ignore.getMessage());
s2aChannelCredentials = createPlaintextToS2AChannelCredentials(plaintextAddress);
return s2aChannelCredentials;
}
s2aChannelCredentials =
buildS2AChannelCredentials(mtlsAddress, mtlsToS2AChannelCredentials);
return s2aChannelCredentials;
} else {
// Fallback to plaintext-to-S2A connection if MTLS-MDS creds do not exist.
LOG.log(
Level.INFO,
"Cannot establish an mTLS connection to S2A because MTLS to MDS credentials do not"
+ " exist on filesystem, falling back to plaintext connection to S2A");
s2aChannelCredentials = createPlaintextToS2AChannelCredentials(plaintextAddress);
return s2aChannelCredentials;
}
}
return buildS2AChannelCredentials(mtlsAddress, mtlsToS2AChannelCredentials);
} else {
// Fallback to plaintext-to-S2A connection if MTLS-MDS creds do not exist.
LOG.log(
Level.INFO,
"Cannot establish an mTLS connection to S2A because MTLS to MDS credentials do not exist on filesystem, falling back to plaintext connection to S2A");
return createPlaintextToS2AChannelCredentials(plaintextAddress);
}
return s2aChannelCredentials;
}

private ManagedChannel createSingleChannel() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,7 @@ void createS2ASecuredChannelCredentials_bothS2AAddressesNull_returnsNull() {
.setS2AConfigProvider(s2aConfigProvider)
.build();
assertThat(provider.createS2ASecuredChannelCredentials()).isNull();
InstantiatingGrpcChannelProvider.resetS2AChannelCredentials();
}

@Test
Expand All @@ -1175,6 +1176,28 @@ void createS2ASecuredChannelCredentials_bothS2AAddressesNull_returnsNull() {
.contains(
"Cannot establish an mTLS connection to S2A because autoconfig endpoint did not return a mtls address to reach S2A.");
InstantiatingGrpcChannelProvider.LOG.removeHandler(logHandler);
InstantiatingGrpcChannelProvider.resetS2AChannelCredentials();
}

@Test
void
createTwoS2ASecuredChannelCredentials_mtlsS2AAddressNull_returnsSamePlaintextToS2AS2AChannelCredentials() {
SecureSessionAgent s2aConfigProvider = Mockito.mock(SecureSessionAgent.class);
SecureSessionAgentConfig config =
SecureSessionAgentConfig.createBuilder().setPlaintextAddress("localhost:8080").build();
Mockito.when(s2aConfigProvider.getConfig()).thenReturn(config);
InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setS2AConfigProvider(s2aConfigProvider)
.build();
assertThat(provider.createS2ASecuredChannelCredentials()).isNotNull();
InstantiatingGrpcChannelProvider provider2 =
InstantiatingGrpcChannelProvider.newBuilder()
.setS2AConfigProvider(s2aConfigProvider)
.build();
assertThat(provider2.createS2ASecuredChannelCredentials()).isNotNull();
assertEquals(provider, provider2);
InstantiatingGrpcChannelProvider.resetS2AChannelCredentials();
}

@Test
Expand All @@ -1197,6 +1220,7 @@ void createS2ASecuredChannelCredentials_returnsPlaintextToS2AS2AChannelCredentia
.contains(
"Cannot establish an mTLS connection to S2A because MTLS to MDS credentials do not exist on filesystem, falling back to plaintext connection to S2A");
InstantiatingGrpcChannelProvider.LOG.removeHandler(logHandler);
InstantiatingGrpcChannelProvider.resetS2AChannelCredentials();
}

@Test
Expand Down
Loading