From 8eefb6efc733804ba5a27c92e8faa4c414c1fc95 Mon Sep 17 00:00:00 2001 From: Cynthia J Date: Thu, 23 Jan 2025 21:13:54 -0800 Subject: [PATCH] imagen model working with new example layout --- .../firebase_vertexai/example/lib/main.dart | 390 +----------------- .../example/lib/pages/image_prompt_page.dart | 67 ++- ...storage_uri_page.dart => imagen_page.dart} | 114 ++--- .../lib/firebase_vertexai.dart | 2 +- 4 files changed, 130 insertions(+), 443 deletions(-) rename packages/firebase_vertexai/firebase_vertexai/example/lib/pages/{storage_uri_page.dart => imagen_page.dart} (59%) diff --git a/packages/firebase_vertexai/firebase_vertexai/example/lib/main.dart b/packages/firebase_vertexai/firebase_vertexai/example/lib/main.dart index 120c1b656425..0813c334e14e 100644 --- a/packages/firebase_vertexai/firebase_vertexai/example/lib/main.dart +++ b/packages/firebase_vertexai/firebase_vertexai/example/lib/main.dart @@ -22,7 +22,7 @@ import 'pages/function_calling_page.dart'; import 'pages/image_prompt_page.dart'; import 'pages/token_count_page.dart'; import 'pages/schema_page.dart'; -import 'pages/storage_uri_page.dart'; +import 'pages/imagen_page.dart'; // REQUIRED if you want to run on Web const FirebaseOptions? options = null; @@ -30,7 +30,7 @@ const FirebaseOptions? options = null; void main() async { WidgetsFlutterBinding.ensureInitialized(); await Firebase.initializeApp(); - await FirebaseAuth.instance.signInAnonymously(); + //await FirebaseAuth.instance.signInAnonymously(); var vertex_instance = FirebaseVertexAI.instanceFor(auth: FirebaseAuth.instance); @@ -79,7 +79,7 @@ class _HomeScreenState extends State { title: 'Function Calling', ), // function calling will initial its own model ImagePromptPage(title: 'Image Prompt', model: widget.model), - StorageUriPromptPage(title: 'Storage URI Prompt', model: widget.model), + ImagenPage(title: 'Imagen Model', model: widget.model), SchemaPromptPage(title: 'Schema Prompt', model: widget.model), ]; @@ -134,11 +134,11 @@ class _HomeScreenState extends State { ), BottomNavigationBarItem( icon: Icon( - Icons.folder, + Icons.image_search, color: Theme.of(context).colorScheme.primary, ), - label: 'Storage URI Prompt', - tooltip: 'Storage URI Prompt', + label: 'Imagen Model', + tooltip: 'Imagen Model', ), BottomNavigationBarItem( icon: Icon( @@ -154,382 +154,4 @@ class _HomeScreenState extends State { ), ); } - - Future _promptSchemaTest(String subject) async { - setState(() { - _loading = true; - }); - try { - final content = [ - Content.text( - "For use in a children's card game, generate 10 animal-based " - 'characters.', - ), - ]; - - final jsonSchema = Schema.object( - properties: { - 'characters': Schema.array( - items: Schema.object( - properties: { - 'name': Schema.string(), - 'age': Schema.integer(), - 'species': Schema.string(), - 'accessory': - Schema.enumString(enumValues: ['hat', 'belt', 'shoes']), - }, - ), - ), - }, - optionalProperties: ['accessory'], - ); - - final response = await _model.generateContent( - content, - generationConfig: GenerationConfig( - responseMimeType: 'application/json', - responseSchema: jsonSchema, - ), - ); - - var text = response.text; - _generatedContent.add((image: null, text: text, fromUser: false)); - - if (text == null) { - _showError('No response from API.'); - return; - } else { - setState(() { - _loading = false; - _scrollDown(); - }); - } - } catch (e) { - _showError(e.toString()); - setState(() { - _loading = false; - }); - } finally { - _textController.clear(); - setState(() { - _loading = false; - }); - _textFieldFocus.requestFocus(); - } - } - - Future _sendStorageUriPrompt(String message) async { - setState(() { - _loading = true; - }); - try { - final content = [ - Content.multi([ - TextPart(message), - FileData( - 'image/jpeg', - 'gs://vertex-ai-example-ef5a2.appspot.com/foodpic.jpg', - ), - ]), - ]; - _generatedContent.add((image: null, text: message, fromUser: true)); - - var response = await _model.generateContent(content); - var text = response.text; - _generatedContent.add((image: null, text: text, fromUser: false)); - - if (text == null) { - _showError('No response from API.'); - return; - } else { - setState(() { - _loading = false; - _scrollDown(); - }); - } - } catch (e) { - _showError(e.toString()); - setState(() { - _loading = false; - }); - } finally { - _textController.clear(); - setState(() { - _loading = false; - }); - _textFieldFocus.requestFocus(); - } - } - - Future _sendImagePrompt(String message) async { - setState(() { - _loading = true; - }); - try { - ByteData catBytes = await rootBundle.load('assets/images/cat.jpg'); - ByteData sconeBytes = await rootBundle.load('assets/images/scones.jpg'); - final content = [ - Content.multi([ - TextPart(message), - // The only accepted mime types are image/*. - InlineDataPart('image/jpeg', catBytes.buffer.asUint8List()), - InlineDataPart('image/jpeg', sconeBytes.buffer.asUint8List()), - ]), - ]; - _generatedContent.add( - ( - image: Image.asset('assets/images/cat.jpg'), - text: message, - fromUser: true - ), - ); - _generatedContent.add( - ( - image: Image.asset('assets/images/scones.jpg'), - text: null, - fromUser: true - ), - ); - - var response = await _model.generateContent(content); - var text = response.text; - _generatedContent.add((image: null, text: text, fromUser: false)); - - if (text == null) { - _showError('No response from API.'); - return; - } else { - setState(() { - _loading = false; - _scrollDown(); - }); - } - } catch (e) { - _showError(e.toString()); - setState(() { - _loading = false; - }); - } finally { - _textController.clear(); - setState(() { - _loading = false; - }); - _textFieldFocus.requestFocus(); - } - } - - Future _sendChatMessage(String message) async { - setState(() { - _loading = true; - }); - - try { - _generatedContent.add((image: null, text: message, fromUser: true)); - var response = await _chat?.sendMessage( - Content.text(message), - ); - var text = response?.text; - _generatedContent.add((image: null, text: text, fromUser: false)); - - if (text == null) { - _showError('No response from API.'); - return; - } else { - setState(() { - _loading = false; - _scrollDown(); - }); - } - } catch (e) { - _showError(e.toString()); - setState(() { - _loading = false; - }); - } finally { - _textController.clear(); - setState(() { - _loading = false; - }); - _textFieldFocus.requestFocus(); - } - } - - Future _testFunctionCalling() async { - setState(() { - _loading = true; - }); - final functionCallChat = _functionCallModel.startChat(); - const prompt = 'What is the weather like in Boston on 10/02 this year?'; - - // Send the message to the generative model. - var response = await functionCallChat.sendMessage( - Content.text(prompt), - ); - - final functionCalls = response.functionCalls.toList(); - // When the model response with a function call, invoke the function. - if (functionCalls.isNotEmpty) { - final functionCall = functionCalls.first; - if (functionCall.name == 'fetchWeather') { - Map location = - functionCall.args['location']! as Map; - var date = functionCall.args['date']! as String; - var city = location['city'] as String; - var state = location['state'] as String; - final functionResult = await fetchWeather(Location(city, state), date); - // Send the response to the model so that it can use the result to - // generate text for the user. - response = await functionCallChat.sendMessage( - Content.functionResponse(functionCall.name, functionResult), - ); - } else { - throw UnimplementedError( - 'Function not declared to the model: ${functionCall.name}', - ); - } - } - // When the model responds with non-null text content, print it. - if (response.text case final text?) { - _generatedContent.add((image: null, text: text, fromUser: false)); - setState(() { - _loading = false; - }); - } - } - - Future _testCountToken() async { - setState(() { - _loading = true; - }); - - const prompt = 'tell a short story'; - var content = Content.text(prompt); - var tokenResponse = await _model.countTokens([content]); - final tokenResult = 'Count token: ${tokenResponse.totalTokens}, billable ' - 'characters: ${tokenResponse.totalBillableCharacters}'; - _generatedContent.add((image: null, text: tokenResult, fromUser: false)); - - var contentResponse = await _model.generateContent([content]); - final contentMetaData = 'result metadata, promptTokenCount:' - '${contentResponse.usageMetadata!.promptTokenCount}, ' - 'candidatesTokenCount:' - '${contentResponse.usageMetadata!.candidatesTokenCount}, ' - 'totalTokenCount:' - '${contentResponse.usageMetadata!.totalTokenCount}'; - _generatedContent - .add((image: null, text: contentMetaData, fromUser: false)); - setState(() { - _loading = false; - }); - } - - Future _testImagen(String prompt) async { - setState(() { - _loading = true; - }); - var model = FirebaseVertexAI.instance.imageModel( - modelName: 'imagen-3.0-generate-001', - generationConfig: ImagenGenerationConfig( - imageFormat: ImagenFormat.jpeg(compressionQuality: 75)), - safetySettings: ImagenSafetySettings( - ImagenSafetyFilterLevel.blockLowAndAbove, - ImagenPersonFilterLevel.allowAdult, - ), - ); - - var generationConfig = ImagenGenerationConfig( - negativePrompt: 'frog', - numberOfImages: 1, - aspectRatio: ImagenAspectRatio.square1x1); - - var response = await model.generateImages( - prompt, - generationConfig: generationConfig, - ); - - if (response.images.isNotEmpty) { - var imagenImage = response.images[0]; - // Process the image - _generatedContent.add( - ( - image: Image.memory(imagenImage.bytesBase64Encoded), - text: prompt, - fromUser: false - ), - ); - } else { - // Handle the case where no images were generated - print('Error: No images were generated.'); - } - setState(() { - _loading = false; - }); - } - - void _showError(String message) { - showDialog( - context: context, - builder: (context) { - return AlertDialog( - title: const Text('Something went wrong'), - content: SingleChildScrollView( - child: SelectableText(message), - ), - actions: [ - TextButton( - onPressed: () { - Navigator.of(context).pop(); - }, - child: const Text('OK'), - ), - ], - ); - }, - ); - } -} - -class MessageWidget extends StatelessWidget { - final Image? image; - final String? text; - final bool isFromUser; - - const MessageWidget({ - super.key, - this.image, - this.text, - required this.isFromUser, - }); - - @override - Widget build(BuildContext context) { - return Row( - mainAxisAlignment: - isFromUser ? MainAxisAlignment.end : MainAxisAlignment.start, - children: [ - Flexible( - child: Container( - constraints: const BoxConstraints(maxWidth: 600), - decoration: BoxDecoration( - color: isFromUser - ? Theme.of(context).colorScheme.primaryContainer - : Theme.of(context).colorScheme.surfaceContainerHighest, - borderRadius: BorderRadius.circular(18), - ), - padding: const EdgeInsets.symmetric( - vertical: 15, - horizontal: 20, - ), - margin: const EdgeInsets.only(bottom: 8), - child: Column( - children: [ - if (text case final text?) MarkdownBody(data: text), - if (image case final image?) image, - ], - ), - ), - ), - ], - ); - } } diff --git a/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/image_prompt_page.dart b/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/image_prompt_page.dart index f8b111296287..1a32cd370d84 100644 --- a/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/image_prompt_page.dart +++ b/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/image_prompt_page.dart @@ -89,13 +89,23 @@ class _ImagePromptPageState extends State { const SizedBox.square( dimension: 15, ), - ElevatedButton( - onPressed: !_loading - ? () async { - await _sendImagePrompt(_textController.text); - } - : null, - child: const Text('Send Image Prompt'), + IconButton( + onPressed: () async { + await _sendImagePrompt(_textController.text); + }, + icon: Icon( + Icons.image, + color: Theme.of(context).colorScheme.primary, + ), + ), + IconButton( + onPressed: () async { + await _sendStorageUriPrompt(_textController.text); + }, + icon: Icon( + Icons.storage, + color: Theme.of(context).colorScheme.primary, + ), ), ], ), @@ -162,6 +172,49 @@ class _ImagePromptPageState extends State { } } + Future _sendStorageUriPrompt(String message) async { + setState(() { + _loading = true; + }); + try { + final content = [ + Content.multi([ + TextPart(message), + FileData( + 'image/jpeg', + 'gs://vertex-ai-example-ef5a2.appspot.com/foodpic.jpg', + ), + ]), + ]; + _generatedContent.add(MessageData(text: message, fromUser: true)); + + var response = await widget.model.generateContent(content); + var text = response.text; + _generatedContent.add(MessageData(text: text, fromUser: false)); + + if (text == null) { + _showError('No response from API.'); + return; + } else { + setState(() { + _loading = false; + _scrollDown(); + }); + } + } catch (e) { + _showError(e.toString()); + setState(() { + _loading = false; + }); + } finally { + _textController.clear(); + setState(() { + _loading = false; + }); + _textFieldFocus.requestFocus(); + } + } + void _showError(String message) { showDialog( context: context, diff --git a/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/storage_uri_page.dart b/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart similarity index 59% rename from packages/firebase_vertexai/firebase_vertexai/example/lib/pages/storage_uri_page.dart rename to packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart index b1624d52efc6..6e2dbb497241 100644 --- a/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/storage_uri_page.dart +++ b/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart @@ -16,8 +16,8 @@ import 'package:flutter/material.dart'; import 'package:firebase_vertexai/firebase_vertexai.dart'; import '../widgets/message_widget.dart'; -class StorageUriPromptPage extends StatefulWidget { - const StorageUriPromptPage({ +class ImagenPage extends StatefulWidget { + const ImagenPage({ super.key, required this.title, required this.model, @@ -27,15 +27,31 @@ class StorageUriPromptPage extends StatefulWidget { final GenerativeModel model; @override - State createState() => _StorageUriPromptPageState(); + State createState() => _ImagenPageState(); } -class _StorageUriPromptPageState extends State { +class _ImagenPageState extends State { final ScrollController _scrollController = ScrollController(); final TextEditingController _textController = TextEditingController(); final FocusNode _textFieldFocus = FocusNode(); - final List _messages = []; + final List _generatedContent = []; bool _loading = false; + late final ImagenModel _imagenModel; + + @override + void initState() { + super.initState(); + _imagenModel = FirebaseVertexAI.instance.imageModel( + modelName: 'imagen-3.0-generate-001', + generationConfig: ImagenGenerationConfig( + imageFormat: ImagenFormat.jpeg(compressionQuality: 75), + ), + safetySettings: ImagenSafetySettings( + ImagenSafetyFilterLevel.blockLowAndAbove, + ImagenPersonFilterLevel.allowAdult, + ), + ); + } void _scrollDown() { WidgetsBinding.instance.addPostFrameCallback( @@ -66,11 +82,12 @@ class _StorageUriPromptPageState extends State { controller: _scrollController, itemBuilder: (context, idx) { return MessageWidget( - text: _messages[idx].text, - isFromUser: _messages[idx].fromUser ?? false, + text: _generatedContent[idx].text, + image: _generatedContent[idx].image, + isFromUser: _generatedContent[idx].fromUser ?? false, ); }, - itemCount: _messages.length, + itemCount: _generatedContent.length, ), ), Padding( @@ -90,14 +107,18 @@ class _StorageUriPromptPageState extends State { const SizedBox.square( dimension: 15, ), - ElevatedButton( - onPressed: !_loading - ? () async { - await _sendStorageUriPrompt(_textController.text); - } - : null, - child: const Text('Send Storage URI Prompt'), - ), + if (!_loading) + IconButton( + onPressed: () async { + await _testImagen(_textController.text); + }, + icon: Icon( + Icons.image_search, + color: Theme.of(context).colorScheme.primary, + ), + ) + else + const CircularProgressIndicator(), ], ), ), @@ -107,47 +128,38 @@ class _StorageUriPromptPageState extends State { ); } - Future _sendStorageUriPrompt(String message) async { + Future _testImagen(String prompt) async { setState(() { _loading = true; }); - try { - final content = [ - Content.multi([ - TextPart(message), - FileData( - 'image/jpeg', - 'gs://vertex-ai-example-ef5a2.appspot.com/foodpic.jpg', - ), - ]), - ]; - _messages.add(MessageData(text: message, fromUser: true)); - var response = await widget.model.generateContent(content); - var text = response.text; - _messages.add(MessageData(text: text, fromUser: false)); + var generationConfig = ImagenGenerationConfig( + negativePrompt: 'frog', + numberOfImages: 1, + aspectRatio: ImagenAspectRatio.square1x1); + + var response = await _imagenModel.generateImages( + prompt, + generationConfig: generationConfig, + ); - if (text == null) { - _showError('No response from API.'); - return; - } else { - setState(() { - _loading = false; - _scrollDown(); - }); - } - } catch (e) { - _showError(e.toString()); - setState(() { - _loading = false; - }); - } finally { - _textController.clear(); - setState(() { - _loading = false; - }); - _textFieldFocus.requestFocus(); + if (response.images.isNotEmpty) { + var imagenImage = response.images[0]; + // Process the image + _generatedContent.add( + MessageData( + image: Image.memory(imagenImage.bytesBase64Encoded), + text: prompt, + fromUser: false, + ), + ); + } else { + // Handle the case where no images were generated + _showError('Error: No images were generated.'); } + setState(() { + _loading = false; + }); } void _showError(String message) { diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/firebase_vertexai.dart b/packages/firebase_vertexai/firebase_vertexai/lib/firebase_vertexai.dart index 19bf7ed37043..af79eed8efd1 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/firebase_vertexai.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/firebase_vertexai.dart @@ -58,7 +58,6 @@ export 'src/function_calling.dart' ToolConfig; export 'src/imagen_api.dart' show - ImagenModelConfig, ImagenSafetySettings, ImagenFormat, ImagenSafetyFilterLevel, @@ -67,5 +66,6 @@ export 'src/imagen_api.dart' ImagenAspectRatio; export 'src/imagen_content.dart' show ImagenInlineImage, ImagenGCSImage, ImagenImage; +export 'src/imagen_model.dart' show ImagenModel; export 'src/model.dart' show GenerativeModel; export 'src/schema.dart' show Schema, SchemaType;