Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(vertexai): add json schema support and update schema structure #13378

Merged
merged 9 commits into from
Sep 24, 2024
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: firebase_app_check_web
description: The web implementation of firebase_auth
description: The web implementation of firebase_app_check
homepage: https://github.com/firebase/flutterfire/tree/main/packages/firebase_app_check/firebase_app_check_web
version: 0.1.3

Expand Down
81 changes: 74 additions & 7 deletions packages/firebase_vertexai/firebase_vertexai/example/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ class _ChatWidgetState extends State<ChatWidget> {

initFirebase().then((value) {
_model = FirebaseVertexAI.instance.generativeModel(
model: 'gemini-1.5-flash-preview-0514',
model: 'gemini-1.5-flash-001',
);
_functionCallModel = FirebaseVertexAI.instance.generativeModel(
model: 'gemini-1.5-flash-preview-0514',
model: 'gemini-1.5-flash-001',
tools: [
Tool(functionDeclarations: [exchangeRateTool]),
],
Expand Down Expand Up @@ -274,6 +274,20 @@ class _ChatWidgetState extends State<ChatWidget> {
: Theme.of(context).colorScheme.primary,
),
),
IconButton(
tooltip: 'schema prompt',
onPressed: !_loading
? () async {
await _promptSchemaTest(_textController.text);
}
: null,
icon: Icon(
Icons.schema,
color: _loading
? Theme.of(context).colorScheme.secondary
: Theme.of(context).colorScheme.primary,
),
),
if (!_loading)
IconButton(
onPressed: () async {
Expand Down Expand Up @@ -304,6 +318,51 @@ class _ChatWidgetState extends State<ChatWidget> {
);
}

Future<void> _promptSchemaTest(String subject) async {
setState(() {
_loading = true;
});
try {
final content = [Content.text('Create a list of 20 $subject.')];

final response = await _model.generateContent(
content,
generationConfig: GenerationConfig(
responseMimeType: 'application/json',
responseSchema: Schema.array(
items: Schema.string(
description: 'A single word that a player will need to guess.',
),
),
),
);

var text = response.text;
_generatedContent.add((image: null, text: text, fromUser: false));

if (text == null) {
_showError('No response from API.');
return;
} else {
setState(() {
_loading = false;
_scrollDown();
});
}
} catch (e) {
_showError(e.toString());
setState(() {
_loading = false;
});
} finally {
_textController.clear();
setState(() {
_loading = false;
});
_textFieldFocus.requestFocus();
}
}

Future<void> _sendStorageUriPrompt(String message) async {
setState(() {
_loading = true;
Expand Down Expand Up @@ -483,11 +542,19 @@ class _ChatWidgetState extends State<ChatWidget> {
});

const prompt = 'tell a short story';
var response = await _model.countTokens([Content.text(prompt)]);
print(
'token: ${response.totalTokens}, billable characters: ${response.totalBillableCharacters}',
);

var content = Content.text(prompt);
var tokenResponse = await _model.countTokens([content]);
final tokenResult =
'Count token: ${tokenResponse.totalTokens}, billable characters: ${tokenResponse.totalBillableCharacters}';
_generatedContent.add((image: null, text: tokenResult, fromUser: false));

var contentResponse = await _model.generateContent([content]);
final contentMetaData =
'result metadata, promptTokenCount:${contentResponse.usageMetadata!.promptTokenCount}, '
'candidatesTokenCount:${contentResponse.usageMetadata!.candidatesTokenCount}, '
'totalTokenCount:${contentResponse.usageMetadata!.totalTokenCount}';
_generatedContent
.add((image: null, text: contentMetaData, fromUser: false));
setState(() {
_loading = false;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

export 'src/firebase_vertexai.dart' show FirebaseVertexAI, RequestOptions;
export 'src/firebase_vertexai.dart' show FirebaseVertexAI;
export 'src/vertex_api.dart'
show
BatchEmbedContentsResponse,
Expand All @@ -33,7 +33,8 @@ export 'src/vertex_api.dart'
PromptFeedback,
SafetyRating,
SafetySetting,
TaskType;
TaskType,
UsageMetadata;
export 'src/vertex_chat.dart' show ChatSession, StartChatExtension;
export 'src/vertex_content.dart'
show
Expand All @@ -56,8 +57,7 @@ export 'src/vertex_function_calling.dart'
FunctionCallingConfig,
FunctionCallingMode,
FunctionDeclaration,
Schema,
SchemaType,
Tool,
ToolConfig;
export 'src/vertex_model.dart' show GenerativeModel;
export 'src/vertex_schema.dart' show Schema, SchemaType;
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,10 @@ import 'vertex_model.dart';

const _defaultLocation = 'us-central1';

/// Default timeout duration, 30 minutes in millisecond
const int defaultTimeout = 1800000;

/// The entrypoint for [FirebaseVertexAI].
class FirebaseVertexAI extends FirebasePluginPlatform {
FirebaseVertexAI._(
{required this.app,
required this.options,
required this.location,
this.appCheck,
this.auth})
{required this.app, required this.location, this.appCheck, this.auth})
: super(app.name, 'plugins.flutter.io/firebase_vertexai');

/// The [FirebaseApp] for this current [FirebaseVertexAI] instance.
Expand All @@ -48,9 +41,6 @@ class FirebaseVertexAI extends FirebasePluginPlatform {
/// The optional [FirebaseAuth] for this current [FirebaseVertexAI] instance.
FirebaseAuth? auth;

/// Configuration parameters for sending requests to the backend.
RequestOptions options;

/// The service location for this [FirebaseVertexAI] instance.
String location;

Expand All @@ -71,7 +61,6 @@ class FirebaseVertexAI extends FirebasePluginPlatform {
FirebaseApp? app,
FirebaseAppCheck? appCheck,
FirebaseAuth? auth,
RequestOptions? options,
String? location,
}) {
app ??= Firebase.app();
Expand All @@ -80,17 +69,10 @@ class FirebaseVertexAI extends FirebasePluginPlatform {
return _cachedInstances[app.name]!;
}

options ??=
RequestOptions(timeout: const Duration(milliseconds: defaultTimeout));

location ??= _defaultLocation;

FirebaseVertexAI newInstance = FirebaseVertexAI._(
app: app,
options: options,
location: location,
appCheck: appCheck,
auth: auth);
app: app, location: location, appCheck: appCheck, auth: auth);
_cachedInstances[app.name] = newInstance;

return newInstance;
Expand All @@ -107,34 +89,25 @@ class FirebaseVertexAI extends FirebasePluginPlatform {
/// The optional [safetySettings] and [generationConfig] can be used to
/// control and guide the generation. See [SafetySetting] and
/// [GenerationConfig] for details.
GenerativeModel generativeModel(
{required String model,
List<SafetySetting>? safetySettings,
GenerationConfig? generationConfig,
Content? systemInstruction,
List<Tool>? tools,
ToolConfig? toolConfig}) {
GenerativeModel generativeModel({
required String model,
List<SafetySetting>? safetySettings,
GenerationConfig? generationConfig,
List<Tool>? tools,
ToolConfig? toolConfig,
Content? systemInstruction,
}) {
return createGenerativeModel(
model: model,
app: app,
appCheck: appCheck,
auth: auth,
location: location,
safetySettings: safetySettings,
generationConfig: generationConfig,
systemInstruction: systemInstruction,
tools: tools,
toolConfig: toolConfig);
model: model,
app: app,
appCheck: appCheck,
auth: auth,
location: location,
safetySettings: safetySettings,
generationConfig: generationConfig,
tools: tools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
);
}
}

/// Options for request to backend.
class RequestOptions {
/// [timeout] duration for the request.
RequestOptions({
required this.timeout,
});

/// Timeout for the request, default to 30 minutes, in milliseconds.
final Duration timeout;
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import 'vertex_content.dart';
import 'vertex_error.dart';
import 'vertex_schema.dart';

/// Response for Count Tokens
final class CountTokensResponse {
Expand Down Expand Up @@ -181,7 +182,7 @@ final class PromptFeedback {
/// Metadata on the generation request's token usage.
final class UsageMetadata {
/// Constructor
UsageMetadata({
UsageMetadata._({
this.promptTokenCount,
this.candidatesTokenCount,
this.totalTokenCount,
Expand Down Expand Up @@ -544,12 +545,13 @@ final class GenerationConfig {
/// Constructor
GenerationConfig(
{this.candidateCount,
this.stopSequences = const [],
this.stopSequences,
this.maxOutputTokens,
this.temperature,
this.topP,
this.topK,
this.responseMimeType});
this.responseMimeType,
this.responseSchema});

/// Number of generated responses to return.
///
Expand All @@ -561,7 +563,7 @@ final class GenerationConfig {
///
/// If specified, the API will stop at the first appearance of a stop
/// sequence. The stop sequence will not be included as part of the response.
final List<String> stopSequences;
final List<String>? stopSequences;

/// The maximum number of tokens to include in a candidate.
///
Expand Down Expand Up @@ -603,18 +605,28 @@ final class GenerationConfig {
/// - `application/json`: JSON response in the candidates.
final String? responseMimeType;

/// Output response schema of the generated candidate text.
///
/// - Note: This only applies when the [responseMimeType] supports
/// a schema; currently this is limited to `application/json`.
final Schema? responseSchema;

/// Convert to json format
Map<String, Object?> toJson() => {
if (candidateCount case final candidateCount?)
'candidateCount': candidateCount,
if (stopSequences.isNotEmpty) 'stopSequences': stopSequences,
if (stopSequences case final stopSequences?
when stopSequences.isNotEmpty)
'stopSequences': stopSequences,
if (maxOutputTokens case final maxOutputTokens?)
'maxOutputTokens': maxOutputTokens,
if (temperature case final temperature?) 'temperature': temperature,
if (topP case final topP?) 'topP': topP,
if (topK case final topK?) 'topK': topK,
if (responseMimeType case final responseMimeType?)
'responseMimeType': responseMimeType,
if (responseSchema case final responseSchema?)
'responseSchema': responseSchema,
};
}

Expand Down Expand Up @@ -786,7 +798,7 @@ UsageMetadata _parseUsageMetadata(Object jsonObject) {
{'totalTokenCount': final int totalTokenCount} => totalTokenCount,
_ => null,
};
return UsageMetadata(
return UsageMetadata._(
promptTokenCount: promptTokenCount,
candidatesTokenCount: candidatesTokenCount,
totalTokenCount: totalTokenCount);
Expand Down
Loading
Loading