Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,10 @@ declare module 'egg' {
maxFreeSockets?: number;
}

type Dispatcher = FetchFactory['getDispatcher'] extends () => infer R
? R
: never;

/** HttpClient config */
export interface HttpClientConfig extends HttpClientBaseConfig {
/** http.Agent */
Expand All @@ -319,8 +323,8 @@ declare module 'egg' {
allowH2?: boolean;
/** Custom lookup function for DNS resolution */
lookup?: LookupFunction;
interceptors?: Parameters<Dispatcher['compose']>;
}

export interface EggAppConfig {
workerStartTimeout: number;
baseDir: string;
Expand Down
26 changes: 21 additions & 5 deletions lib/core/fetch_factory.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ const debug = require('util').debuglog('egg:lib:core:fetch_factory');
const mainNodejsVersion = parseInt(process.versions.node.split('.')[0]);
let FetchFactory;
let fetch;
let fetchInitialized = false;
// Track initialization per app instance by storing a WeakMap
const fetchInitializedMap = new WeakMap();
let safeFetch;
let ssrfFetchFactory;

Expand All @@ -14,15 +15,24 @@ if (mainNodejsVersion >= 20) {
FetchFactory = urllib4.FetchFactory;
debug('urllib4 enable');


fetch = function fetch(url, init) {
if (!fetchInitialized) {
fetch = function(url, init) {
if (!fetchInitializedMap.get(this)) {
const clientOptions = {};
if (this.config.httpclient?.lookup) {
clientOptions.lookup = this.config.httpclient.lookup;
}
FetchFactory.setClientOptions(clientOptions);
fetchInitialized = true;

// Support custom interceptors via dispatcher.compose
// Must be set after setClientOptions because setClientOptions resets dispatcher
// interceptors is an array of interceptor functions that follow undici's dispatcher API(undici have not supported clientOptions.interceptors natively yet)
if (this.config.httpclient?.interceptors) {
const interceptors = this.config.httpclient.interceptors;
const originalDispatcher = FetchFactory.getDispatcher();
FetchFactory.setDispatcher(originalDispatcher.compose(interceptors));
}

fetchInitializedMap.set(this, true);
}
return FetchFactory.fetch(url, init);
};
Expand All @@ -41,6 +51,12 @@ if (mainNodejsVersion >= 20) {
}
ssrfFetchFactory = new FetchFactory();
ssrfFetchFactory.setClientOptions(clientOptions);

if (this.config.httpclient?.interceptors) {
const interceptors = this.config.httpclient.interceptors;
const originalDispatcher = ssrfFetchFactory.getDispatcher();
ssrfFetchFactory.setDispatcher(originalDispatcher.compose(interceptors));
}
Comment on lines +55 to +59
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of code for applying interceptors is a duplicate of the logic in the fetch function (lines 28-32). To improve maintainability and avoid code duplication, consider extracting this logic into a shared helper function.

For example, you could define a function like this:

function applyInterceptors(factory, config) {
  if (config.httpclient?.interceptors) {
    const interceptors = config.httpclient.interceptors;
    const originalDispatcher = factory.getDispatcher();
    factory.setDispatcher(originalDispatcher.compose(interceptors));
  }
}

And then call it from both fetch and safeFetch initialization blocks:

// in fetch
applyInterceptors(FetchFactory, this.config);

// in safeFetch
applyInterceptors(ssrfFetchFactory, this.config);

This will make the code cleaner and easier to manage in the future.

}
return ssrfFetchFactory.fetch(url, init);
};
Expand Down
72 changes: 72 additions & 0 deletions test/fixtures/apps/fetch-tracer/app.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
const assert = require('assert');

const TRACE_ID = Symbol('TRACE_ID');
const RPC_ID = Symbol('RPC_ID');

// Simple Tracer implementation
class Tracer {
constructor(traceId, rpcId = '0') {
this.traceId = traceId;
this._rpcId = rpcId;
this._rpcIdSeq = 0;
}

get rpcId() {
return this._rpcId;
}

get rpcIdPlus() {
return `${this._rpcId}.${++this._rpcIdSeq}`;
}
}

module.exports = class TracerApp {
constructor(app) {
this.app = app;
assert(app.config);
// Expose Tracer class for testing
app.Tracer = Tracer;
}

configWillLoad() {
// Setup tracer interceptor using interceptors config
this.app.config.httpclient = this.app.config.httpclient || {};
if (!this.app.FetchFactory) {
return;
}
const tracerConfig = this.app.config.tracer;
const HTTP_HEADER_TRACE_ID_KEY = tracerConfig.HTTP_HEADER_TRACE_ID_KEY.toLowerCase();
const HTTP_HEADER_RPC_ID_KEY = tracerConfig.HTTP_HEADER_RPC_ID_KEY.toLowerCase();

this.app.config.httpclient.interceptors = [
dispatch => {
const app = this.app;
return async function tracerInterceptor(opts, handler) {
const tracer = app.currentContext?.tracer;
let traceId;
let rpcId;

try {
if (tracer) {
traceId = opts.headers[HTTP_HEADER_TRACE_ID_KEY] = tracer.traceId;
rpcId = opts.headers[HTTP_HEADER_RPC_ID_KEY] = tracer.rpcIdPlus;
}
} catch (e) {
e.message = '[egg-tracelog] set tracer header failed: ' + e.message;
app.logger.warn(e);
}

try {
return await dispatch(opts, handler);
} finally {
const opaque = handler.opaque;
if (opaque) {
opaque[TRACE_ID] = traceId;
opaque[RPC_ID] = rpcId;
}
}
};
},
];
}
};
22 changes: 22 additions & 0 deletions test/fixtures/apps/fetch-tracer/app/router.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
module.exports = app => {
app.get('/test', async ctx => {
// Mock a tracer on the context using the Tracer class
ctx.tracer = new app.Tracer('test-trace-id-123', '0');

// Store the current context so fetch can access it
app.currentContext = ctx;

// Make a fetch request
const response = await app.fetch(ctx.query.url);

const traceId = response.headers.get('x-trace-id');
if (traceId) ctx.set('x-trace-id', traceId);
const rpcId = response.headers.get('x-rpc-id');
if (rpcId) ctx.set('x-rpc-id', rpcId);

ctx.body = {
status: response.status,
ok: response.ok,
};
});
};
6 changes: 6 additions & 0 deletions test/fixtures/apps/fetch-tracer/config/config.default.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
exports.keys = 'test key';

exports.tracer = {
HTTP_HEADER_TRACE_ID_KEY: 'x-trace-id',
HTTP_HEADER_RPC_ID_KEY: 'x-rpc-id',
};
3 changes: 3 additions & 0 deletions test/fixtures/apps/fetch-tracer/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"name": "fetch-tracer"
}
116 changes: 116 additions & 0 deletions test/lib/core/fetch_tracer.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
const assert = require('node:assert');
const http = require('http');
const utils = require('../../utils');

describe('test/lib/core/fetch_tracer.test.js', () => {
const version = utils.getNodeVersion();
if (version < 20) return;

let app;
let mockServer;

before(async () => {
// Create a mock server to capture headers
mockServer = http.createServer((req, res) => {
const headers = {
'Content-Type': 'application/json',
};
if (req.headers['x-trace-id']) {
headers['x-trace-id'] = req.headers['x-trace-id'];
}
if (req.headers['x-rpc-id']) {
headers['x-rpc-id'] = req.headers['x-rpc-id'];
}

res.writeHead(200, headers);
res.end(JSON.stringify({ ok: true }));
});

await new Promise(resolve => {
mockServer.listen(0, '127.0.0.1', resolve);
});

app = utils.app('apps/fetch-tracer');
await app.ready();
});

after(() => {
if (mockServer?.listening) {
mockServer.close();
}
});

it('should add tracer headers when fetch is called', async () => {
const port = mockServer.address().port;
const targetUrl = `http://127.0.0.1:${port}/mock`;

const response = await app.httpRequest()
.get('/test')
.query({ url: targetUrl })
.expect(200);

assert.strictEqual(response.body.status, 200);
assert.strictEqual(response.body.ok, true);

// Verify tracer headers were added with incremented rpcId
assert.strictEqual(response.headers['x-trace-id'], 'test-trace-id-123');
assert.strictEqual(response.headers['x-rpc-id'], '0.1'); // rpcIdPlus increments from 0
});

it('should work when tracer is not set', async () => {
// Clear currentContext
app.currentContext = null;

const port = mockServer.address().port;
const targetUrl = `http://127.0.0.1:${port}/mock`;

const response = await app.fetch(targetUrl);

assert.strictEqual(response.status, 200);

// Verify no tracer headers when tracer is not set
assert.strictEqual(response.headers.get('x-trace-id'), null);
assert.strictEqual(response.headers.get('x-rpc-id'), null);
});


it('should handle fetch before configDidLoad completes', async () => {
// Test that lazy initialization preserves interceptors set in configDidLoad
const port = mockServer.address().port;
const targetUrl = `http://127.0.0.1:${port}/mock`;

const ctx = app.mockContext();
ctx.tracer = new app.Tracer('early-trace-id', '0.1');
app.currentContext = ctx;

const response = await app.fetch(targetUrl);
assert.strictEqual(response.status, 200);
assert.strictEqual(response.headers.get('x-trace-id'), 'early-trace-id');
assert.strictEqual(response.headers.get('x-rpc-id'), '0.1.1'); // rpcIdPlus increments from 0.1
});

it('should increment rpcId on multiple fetch calls', async () => {
// Test that rpcId increments properly on each fetch
const port = mockServer.address().port;
const targetUrl = `http://127.0.0.1:${port}/mock`;

const ctx = app.mockContext();
ctx.tracer = new app.Tracer('multi-trace-id', '0');
app.currentContext = ctx;

// First fetch
let response = await app.fetch(targetUrl);
assert.strictEqual(response.headers.get('x-trace-id'), 'multi-trace-id');
assert.strictEqual(response.headers.get('x-rpc-id'), '0.1');

// Second fetch
response = await app.fetch(targetUrl);
assert.strictEqual(response.headers.get('x-trace-id'), 'multi-trace-id');
assert.strictEqual(response.headers.get('x-rpc-id'), '0.2');

// Third fetch
response = await app.fetch(targetUrl);
assert.strictEqual(response.headers.get('x-trace-id'), 'multi-trace-id');
assert.strictEqual(response.headers.get('x-rpc-id'), '0.3');
});
});
Loading