feat (provider/openai): predicted outputs token usage (#4252)
Co-authored-by: Gwenaël Gallon <[email protected]>
lgrammel and ggallon authored Jan 3, 2025
1 parent 88eec24 commit b19aa82
Showing 5 changed files with 169 additions and 42 deletions.
5 changes: 5 additions & 0 deletions .changeset/nasty-beers-rule.md
@@ -0,0 +1,5 @@
---
'@ai-sdk/openai': patch
---

feat (provider/openai): add predicted outputs token usage
16 changes: 16 additions & 0 deletions content/providers/01-ai-sdk-providers/01-openai.mdx
@@ -292,6 +292,22 @@ const result = streamText({
});
```

OpenAI reports token usage for predicted outputs as `acceptedPredictionTokens` and `rejectedPredictionTokens`.
You can access these values through the `experimental_providerMetadata` object on the result.

```ts
const openaiMetadata = (await result.experimental_providerMetadata)?.openai;

const acceptedPredictionTokens = openaiMetadata?.acceptedPredictionTokens;
const rejectedPredictionTokens = openaiMetadata?.rejectedPredictionTokens;
```

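When logging token usage, you can merge these values with the standard usage information. A minimal sketch, assuming `result` comes from the `streamText` call shown above:

```ts
// Combine standard token usage with predicted-output token counts.
const usage = await result.usage;
const openaiMetadata = (await result.experimental_providerMetadata)?.openai;

console.log('Token usage:', {
  ...usage,
  acceptedPredictionTokens: openaiMetadata?.acceptedPredictionTokens,
  rejectedPredictionTokens: openaiMetadata?.rejectedPredictionTokens,
});
```
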
<Note type="warning">
OpenAI Predicted Outputs have several
[limitations](https://platform.openai.com/docs/guides/predicted-outputs#limitations),
such as unsupported API parameters and the lack of tool calling support.
</Note>

#### Image Detail

You can use the `openai` provider metadata to set the [image generation detail](https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding) to `high`, `low`, or `auto`:
9 changes: 8 additions & 1 deletion examples/ai-core/src/stream-text/openai-predicted-output.ts
@@ -53,8 +53,15 @@ async function main() {
process.stdout.write(textPart);
}

const usage = await result.usage;
const openaiMetadata = (await result.experimental_providerMetadata)?.openai;

console.log();
console.log('Token usage:', await result.usage);
console.log('Token usage:', {
...usage,
acceptedPredictionTokens: openaiMetadata?.acceptedPredictionTokens,
rejectedPredictionTokens: openaiMetadata?.rejectedPredictionTokens,
});
console.log('Finish reason:', await result.finishReason);
}

99 changes: 91 additions & 8 deletions packages/openai/src/openai-chat-language-model.test.ts
@@ -175,6 +175,8 @@ describe('doGenerate', () => {
completion_tokens?: number;
completion_tokens_details?: {
reasoning_tokens?: number;
accepted_prediction_tokens?: number;
rejected_prediction_tokens?: number;
};
prompt_tokens_details?: {
cached_tokens?: number;
@@ -1088,6 +1090,35 @@ describe('doGenerate', () => {
});
});

  it('should return accepted_prediction_tokens and rejected_prediction_tokens in providerMetadata', async () => {
prepareJsonResponse({
usage: {
prompt_tokens: 15,
completion_tokens: 20,
total_tokens: 35,
completion_tokens_details: {
accepted_prediction_tokens: 123,
rejected_prediction_tokens: 456,
},
},
});

const model = provider.chat('gpt-4o-mini');

const result = await model.doGenerate({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
});

expect(result.providerMetadata).toStrictEqual({
openai: {
acceptedPredictionTokens: 123,
rejectedPredictionTokens: 456,
},
});
});

describe('reasoning models', () => {
it('should clear out temperature, top_p, frequency_penalty, presence_penalty', async () => {
prepareJsonResponse();
@@ -1262,7 +1293,9 @@ describe('doStream', () => {
cached_tokens?: number;
};
completion_tokens_details?: {
reasoning_tokens: number;
reasoning_tokens?: number;
accepted_prediction_tokens?: number;
rejected_prediction_tokens?: number;
};
};
logprobs?: {
@@ -1333,6 +1366,7 @@ describe('doStream', () => {
finishReason: 'stop',
logprobs: mapOpenAIChatLogProbsOutput(TEST_LOGPROBS),
usage: { promptTokens: 17, completionTokens: 227 },
providerMetadata: { openai: {} },
},
]);
});
@@ -1460,6 +1494,7 @@
finishReason: 'tool-calls',
logprobs: undefined,
usage: { promptTokens: 53, completionTokens: 17 },
providerMetadata: { openai: {} },
},
]);
});
@@ -1594,6 +1629,7 @@
finishReason: 'tool-calls',
logprobs: undefined,
usage: { promptTokens: 53, completionTokens: 17 },
providerMetadata: { openai: {} },
},
]);
});
@@ -1717,6 +1753,7 @@
finishReason: 'tool-calls',
logprobs: undefined,
usage: { promptTokens: 226, completionTokens: 20 },
providerMetadata: { openai: {} },
},
]);
});
@@ -1781,6 +1818,7 @@
finishReason: 'tool-calls',
logprobs: undefined,
usage: { promptTokens: 53, completionTokens: 17 },
providerMetadata: { openai: {} },
},
]);
});
@@ -1857,6 +1895,7 @@
finishReason: 'stop',
logprobs: undefined,
usage: { promptTokens: 53, completionTokens: 17 },
providerMetadata: { openai: {} },
},
]);
});
@@ -1895,6 +1934,7 @@
completionTokens: NaN,
promptTokens: NaN,
},
providerMetadata: { openai: {} },
},
]);
});
@@ -1920,6 +1960,7 @@ describe('doStream', () => {
completionTokens: NaN,
promptTokens: NaN,
},
providerMetadata: { openai: {} },
});
});

@@ -2011,7 +2052,7 @@ describe('doStream', () => {
});
});

it('should handle cached tokens in experimental_providerMetadata', async () => {
it('should return cached tokens in providerMetadata', async () => {
prepareStreamResponse({
content: [],
usage: {
@@ -2037,10 +2078,7 @@
messages: [{ role: 'user', content: 'Hello' }],
});

const chunksArr = await convertReadableStreamToArray(stream);
expect(chunksArr[chunksArr.length - 1]).toHaveProperty('providerMetadata');
expect(chunksArr[chunksArr.length - 1].type).toEqual('finish');
expect(chunksArr[chunksArr.length - 1]).toStrictEqual({
expect((await convertReadableStreamToArray(stream)).at(-1)).toStrictEqual({
type: 'finish',
finishReason: 'stop',
logprobs: undefined,
@@ -2054,6 +2092,50 @@
});
});

it('should return accepted_prediction_tokens and rejected_prediction_tokens in providerMetadata', async () => {
prepareStreamResponse({
content: [],
usage: {
prompt_tokens: 15,
completion_tokens: 20,
total_tokens: 35,
completion_tokens_details: {
accepted_prediction_tokens: 123,
rejected_prediction_tokens: 456,
},
},
});

const { stream } = await model.doStream({
inputFormat: 'prompt',
mode: { type: 'regular' },
prompt: TEST_PROMPT,
});

expect(await server.getRequestBodyJson()).toStrictEqual({
stream: true,
stream_options: { include_usage: true },
model: 'gpt-3.5-turbo',
messages: [{ role: 'user', content: 'Hello' }],
});

expect((await convertReadableStreamToArray(stream)).at(-1)).toStrictEqual({
type: 'finish',
finishReason: 'stop',
logprobs: undefined,
usage: {
promptTokens: 15,
completionTokens: 20,
},
providerMetadata: {
openai: {
acceptedPredictionTokens: 123,
rejectedPredictionTokens: 456,
},
},
});
});

it('should send store extension setting', async () => {
prepareStreamResponse({ content: [] });

@@ -2133,6 +2215,7 @@ describe('doStream', () => {
finishReason: 'stop',
usage: { promptTokens: 17, completionTokens: 227 },
logprobs: undefined,
providerMetadata: { openai: {} },
},
]);
});
@@ -2292,7 +2375,7 @@ describe('doStream simulated streaming', () => {
finishReason: 'stop',
usage: { promptTokens: 4, completionTokens: 30 },
logprobs: undefined,
providerMetadata: undefined,
providerMetadata: { openai: {} },
},
]);
});
@@ -2356,7 +2439,7 @@ describe('doStream simulated streaming', () => {
finishReason: 'stop',
usage: { promptTokens: 4, completionTokens: 30 },
logprobs: undefined,
providerMetadata: undefined,
providerMetadata: { openai: {} },
},
]);
});
82 changes: 49 additions & 33 deletions packages/openai/src/openai-chat-language-model.ts
@@ -309,20 +309,25 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {
const { messages: rawPrompt, ...rawSettings } = body;
const choice = response.choices[0];

let providerMetadata: LanguageModelV1ProviderMetadata | undefined;
if (
response.usage?.completion_tokens_details?.reasoning_tokens != null ||
response.usage?.prompt_tokens_details?.cached_tokens != null
) {
providerMetadata = { openai: {} };
if (response.usage?.completion_tokens_details?.reasoning_tokens != null) {
providerMetadata.openai.reasoningTokens =
response.usage?.completion_tokens_details?.reasoning_tokens;
}
if (response.usage?.prompt_tokens_details?.cached_tokens != null) {
providerMetadata.openai.cachedPromptTokens =
response.usage?.prompt_tokens_details?.cached_tokens;
}
// provider metadata:
const completionTokenDetails = response.usage?.completion_tokens_details;
const promptTokenDetails = response.usage?.prompt_tokens_details;
const providerMetadata: LanguageModelV1ProviderMetadata = { openai: {} };
if (completionTokenDetails?.reasoning_tokens != null) {
providerMetadata.openai.reasoningTokens =
completionTokenDetails?.reasoning_tokens;
}
if (completionTokenDetails?.accepted_prediction_tokens != null) {
providerMetadata.openai.acceptedPredictionTokens =
completionTokenDetails?.accepted_prediction_tokens;
}
if (completionTokenDetails?.rejected_prediction_tokens != null) {
providerMetadata.openai.rejectedPredictionTokens =
completionTokenDetails?.rejected_prediction_tokens;
}
if (promptTokenDetails?.cached_tokens != null) {
providerMetadata.openai.cachedPromptTokens =
promptTokenDetails?.cached_tokens;
}

return {
@@ -451,7 +456,8 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {

const { useLegacyFunctionCalling } = this.settings;

let providerMetadata: LanguageModelV1ProviderMetadata | undefined;
const providerMetadata: LanguageModelV1ProviderMetadata = { openai: {} };

return {
stream: response.pipeThrough(
new TransformStream<
@@ -485,29 +491,37 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {
}

if (value.usage != null) {
usage = {
promptTokens: value.usage.prompt_tokens ?? undefined,
completionTokens: value.usage.completion_tokens ?? undefined,
};

const {
completion_tokens_details: completionTokenDetails,
prompt_tokens_details: promptTokenDetails,
prompt_tokens,
completion_tokens,
prompt_tokens_details,
completion_tokens_details,
} = value.usage;

usage = {
promptTokens: prompt_tokens ?? undefined,
completionTokens: completion_tokens ?? undefined,
};

if (completion_tokens_details?.reasoning_tokens != null) {
providerMetadata.openai.reasoningTokens =
completion_tokens_details?.reasoning_tokens;
}
if (
completionTokenDetails?.reasoning_tokens != null ||
promptTokenDetails?.cached_tokens != null
completion_tokens_details?.accepted_prediction_tokens != null
) {
providerMetadata = { openai: {} };
if (completionTokenDetails?.reasoning_tokens != null) {
providerMetadata.openai.reasoningTokens =
completionTokenDetails?.reasoning_tokens;
}
if (promptTokenDetails?.cached_tokens != null) {
providerMetadata.openai.cachedPromptTokens =
promptTokenDetails?.cached_tokens;
}
providerMetadata.openai.acceptedPredictionTokens =
completion_tokens_details?.accepted_prediction_tokens;
}
if (
completion_tokens_details?.rejected_prediction_tokens != null
) {
providerMetadata.openai.rejectedPredictionTokens =
completion_tokens_details?.rejected_prediction_tokens;
}
if (prompt_tokens_details?.cached_tokens != null) {
providerMetadata.openai.cachedPromptTokens =
prompt_tokens_details?.cached_tokens;
}
}

@@ -695,6 +709,8 @@ const openaiTokenUsageSchema = z
completion_tokens_details: z
.object({
reasoning_tokens: z.number().nullish(),
accepted_prediction_tokens: z.number().nullish(),
rejected_prediction_tokens: z.number().nullish(),
})
.nullish(),
})
