Skip to content

Commit

Permalink
feat(vertexai): add json schema support and update schema structure (#…
Browse files Browse the repository at this point in the history
…13378)

* Add schema support for prompt response

* update schema object, replace requiredProperties with optionalProperties

* api updates (breaking)

* more tweaks from API review

* Add format for Schema.string

* fix test files

* Update packages/firebase_vertexai/firebase_vertexai/lib/src/vertex_api.dart

Co-authored-by: Nate Bosch <[email protected]>

* add review feedback

---------

Co-authored-by: Nate Bosch <[email protected]>
  • Loading branch information
cynthiajoan and natebosch authored Sep 24, 2024
1 parent 68b0b14 commit 5ffb204
Show file tree
Hide file tree
Showing 11 changed files with 482 additions and 437 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: firebase_app_check_web
description: The web implementation of firebase_auth
description: The web implementation of firebase_app_check
homepage: https://github.com/firebase/flutterfire/tree/main/packages/firebase_app_check/firebase_app_check_web
version: 0.1.3

Expand Down
81 changes: 74 additions & 7 deletions packages/firebase_vertexai/firebase_vertexai/example/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ class _ChatWidgetState extends State<ChatWidget> {

initFirebase().then((value) {
_model = FirebaseVertexAI.instance.generativeModel(
model: 'gemini-1.5-flash-preview-0514',
model: 'gemini-1.5-flash-001',
);
_functionCallModel = FirebaseVertexAI.instance.generativeModel(
model: 'gemini-1.5-flash-preview-0514',
model: 'gemini-1.5-flash-001',
tools: [
Tool(functionDeclarations: [exchangeRateTool]),
],
Expand Down Expand Up @@ -274,6 +274,20 @@ class _ChatWidgetState extends State<ChatWidget> {
: Theme.of(context).colorScheme.primary,
),
),
IconButton(
tooltip: 'schema prompt',
onPressed: !_loading
? () async {
await _promptSchemaTest(_textController.text);
}
: null,
icon: Icon(
Icons.schema,
color: _loading
? Theme.of(context).colorScheme.secondary
: Theme.of(context).colorScheme.primary,
),
),
if (!_loading)
IconButton(
onPressed: () async {
Expand Down Expand Up @@ -304,6 +318,51 @@ class _ChatWidgetState extends State<ChatWidget> {
);
}

  /// Sends a schema-constrained prompt asking the model for a JSON array of
  /// 20 strings related to [subject], and appends the raw JSON text of the
  /// response to the chat history.
  ///
  /// Uses [GenerationConfig.responseSchema] with `responseMimeType:
  /// 'application/json'` so the model is forced to return a JSON array of
  /// strings rather than free-form text.
  ///
  /// NOTE(review): relies on members defined elsewhere in this State class
  /// (`_model`, `_generatedContent`, `_showError`, `_scrollDown`,
  /// `_textController`, `_textFieldFocus`) — presumably following the same
  /// conventions as the sibling `_send*Prompt` methods; confirm against the
  /// full file.
  Future<void> _promptSchemaTest(String subject) async {
    // Disable the input buttons while the request is in flight.
    setState(() {
      _loading = true;
    });
    try {
      final content = [Content.text('Create a list of 20 $subject.')];

      final response = await _model.generateContent(
        content,
        generationConfig: GenerationConfig(
          // Required for responseSchema: only 'application/json' currently
          // supports schema-constrained output.
          responseMimeType: 'application/json',
          responseSchema: Schema.array(
            items: Schema.string(
              description: 'A single word that a player will need to guess.',
            ),
          ),
        ),
      );

      var text = response.text;
      // The (possibly null) text is added to the history before the null
      // check, so an empty API response still produces a history entry.
      _generatedContent.add((image: null, text: text, fromUser: false));

      if (text == null) {
        _showError('No response from API.');
        // The finally block below still runs and clears _loading.
        return;
      } else {
        setState(() {
          _loading = false;
          _scrollDown();
        });
      }
    } catch (e) {
      _showError(e.toString());
      setState(() {
        _loading = false;
      });
    } finally {
      // Always reset the input field and spinner, whether the request
      // succeeded, returned null, or threw.
      _textController.clear();
      setState(() {
        _loading = false;
      });
      _textFieldFocus.requestFocus();
    }
  }

Future<void> _sendStorageUriPrompt(String message) async {
setState(() {
_loading = true;
Expand Down Expand Up @@ -483,11 +542,19 @@ class _ChatWidgetState extends State<ChatWidget> {
});

const prompt = 'tell a short story';
var response = await _model.countTokens([Content.text(prompt)]);
print(
'token: ${response.totalTokens}, billable characters: ${response.totalBillableCharacters}',
);

var content = Content.text(prompt);
var tokenResponse = await _model.countTokens([content]);
final tokenResult =
'Count token: ${tokenResponse.totalTokens}, billable characters: ${tokenResponse.totalBillableCharacters}';
_generatedContent.add((image: null, text: tokenResult, fromUser: false));

var contentResponse = await _model.generateContent([content]);
final contentMetaData =
'result metadata, promptTokenCount:${contentResponse.usageMetadata!.promptTokenCount}, '
'candidatesTokenCount:${contentResponse.usageMetadata!.candidatesTokenCount}, '
'totalTokenCount:${contentResponse.usageMetadata!.totalTokenCount}';
_generatedContent
.add((image: null, text: contentMetaData, fromUser: false));
setState(() {
_loading = false;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

export 'src/firebase_vertexai.dart' show FirebaseVertexAI, RequestOptions;
export 'src/firebase_vertexai.dart' show FirebaseVertexAI;
export 'src/vertex_api.dart'
show
BatchEmbedContentsResponse,
Expand All @@ -33,7 +33,8 @@ export 'src/vertex_api.dart'
PromptFeedback,
SafetyRating,
SafetySetting,
TaskType;
TaskType,
UsageMetadata;
export 'src/vertex_chat.dart' show ChatSession, StartChatExtension;
export 'src/vertex_content.dart'
show
Expand All @@ -56,8 +57,7 @@ export 'src/vertex_function_calling.dart'
FunctionCallingConfig,
FunctionCallingMode,
FunctionDeclaration,
Schema,
SchemaType,
Tool,
ToolConfig;
export 'src/vertex_model.dart' show GenerativeModel;
export 'src/vertex_schema.dart' show Schema, SchemaType;
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,10 @@ import 'vertex_model.dart';

const _defaultLocation = 'us-central1';

/// Default timeout duration, 30 minutes in milliseconds
const int defaultTimeout = 1800000;

/// The entrypoint for [FirebaseVertexAI].
class FirebaseVertexAI extends FirebasePluginPlatform {
FirebaseVertexAI._(
{required this.app,
required this.options,
required this.location,
this.appCheck,
this.auth})
{required this.app, required this.location, this.appCheck, this.auth})
: super(app.name, 'plugins.flutter.io/firebase_vertexai');

/// The [FirebaseApp] for this current [FirebaseVertexAI] instance.
Expand All @@ -48,9 +41,6 @@ class FirebaseVertexAI extends FirebasePluginPlatform {
/// The optional [FirebaseAuth] for this current [FirebaseVertexAI] instance.
FirebaseAuth? auth;

/// Configuration parameters for sending requests to the backend.
RequestOptions options;

/// The service location for this [FirebaseVertexAI] instance.
String location;

Expand All @@ -71,7 +61,6 @@ class FirebaseVertexAI extends FirebasePluginPlatform {
FirebaseApp? app,
FirebaseAppCheck? appCheck,
FirebaseAuth? auth,
RequestOptions? options,
String? location,
}) {
app ??= Firebase.app();
Expand All @@ -80,17 +69,10 @@ class FirebaseVertexAI extends FirebasePluginPlatform {
return _cachedInstances[app.name]!;
}

options ??=
RequestOptions(timeout: const Duration(milliseconds: defaultTimeout));

location ??= _defaultLocation;

FirebaseVertexAI newInstance = FirebaseVertexAI._(
app: app,
options: options,
location: location,
appCheck: appCheck,
auth: auth);
app: app, location: location, appCheck: appCheck, auth: auth);
_cachedInstances[app.name] = newInstance;

return newInstance;
Expand All @@ -107,34 +89,25 @@ class FirebaseVertexAI extends FirebasePluginPlatform {
/// The optional [safetySettings] and [generationConfig] can be used to
/// control and guide the generation. See [SafetySetting] and
/// [GenerationConfig] for details.
GenerativeModel generativeModel(
{required String model,
List<SafetySetting>? safetySettings,
GenerationConfig? generationConfig,
Content? systemInstruction,
List<Tool>? tools,
ToolConfig? toolConfig}) {
GenerativeModel generativeModel({
required String model,
List<SafetySetting>? safetySettings,
GenerationConfig? generationConfig,
List<Tool>? tools,
ToolConfig? toolConfig,
Content? systemInstruction,
}) {
return createGenerativeModel(
model: model,
app: app,
appCheck: appCheck,
auth: auth,
location: location,
safetySettings: safetySettings,
generationConfig: generationConfig,
systemInstruction: systemInstruction,
tools: tools,
toolConfig: toolConfig);
model: model,
app: app,
appCheck: appCheck,
auth: auth,
location: location,
safetySettings: safetySettings,
generationConfig: generationConfig,
tools: tools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
);
}
}

/// Configuration options applied to requests sent to the backend.
class RequestOptions {
  /// Creates request options with the given [timeout].
  RequestOptions({required this.timeout});

  /// How long a request may run before it is abandoned.
  ///
  /// Defaults (at the call site) to 30 minutes.
  final Duration timeout;
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import 'vertex_content.dart';
import 'vertex_error.dart';
import 'vertex_schema.dart';

/// Response for Count Tokens
final class CountTokensResponse {
Expand Down Expand Up @@ -181,7 +182,7 @@ final class PromptFeedback {
/// Metadata on the generation request's token usage.
final class UsageMetadata {
/// Constructor
UsageMetadata({
UsageMetadata._({
this.promptTokenCount,
this.candidatesTokenCount,
this.totalTokenCount,
Expand Down Expand Up @@ -544,12 +545,13 @@ final class GenerationConfig {
/// Constructor
GenerationConfig(
{this.candidateCount,
this.stopSequences = const [],
this.stopSequences,
this.maxOutputTokens,
this.temperature,
this.topP,
this.topK,
this.responseMimeType});
this.responseMimeType,
this.responseSchema});

/// Number of generated responses to return.
///
Expand All @@ -561,7 +563,7 @@ final class GenerationConfig {
///
/// If specified, the API will stop at the first appearance of a stop
/// sequence. The stop sequence will not be included as part of the response.
final List<String> stopSequences;
final List<String>? stopSequences;

/// The maximum number of tokens to include in a candidate.
///
Expand Down Expand Up @@ -603,18 +605,28 @@ final class GenerationConfig {
/// - `application/json`: JSON response in the candidates.
final String? responseMimeType;

/// Output response schema of the generated candidate text.
///
/// - Note: This only applies when the [responseMimeType] supports
/// a schema; currently this is limited to `application/json`.
final Schema? responseSchema;

/// Convert to json format
Map<String, Object?> toJson() => {
if (candidateCount case final candidateCount?)
'candidateCount': candidateCount,
if (stopSequences.isNotEmpty) 'stopSequences': stopSequences,
if (stopSequences case final stopSequences?
when stopSequences.isNotEmpty)
'stopSequences': stopSequences,
if (maxOutputTokens case final maxOutputTokens?)
'maxOutputTokens': maxOutputTokens,
if (temperature case final temperature?) 'temperature': temperature,
if (topP case final topP?) 'topP': topP,
if (topK case final topK?) 'topK': topK,
if (responseMimeType case final responseMimeType?)
'responseMimeType': responseMimeType,
if (responseSchema case final responseSchema?)
'responseSchema': responseSchema,
};
}

Expand Down Expand Up @@ -786,7 +798,7 @@ UsageMetadata _parseUsageMetadata(Object jsonObject) {
{'totalTokenCount': final int totalTokenCount} => totalTokenCount,
_ => null,
};
return UsageMetadata(
return UsageMetadata._(
promptTokenCount: promptTokenCount,
candidatesTokenCount: candidatesTokenCount,
totalTokenCount: totalTokenCount);
Expand Down
Loading

0 comments on commit 5ffb204

Please sign in to comment.