From 8922f9b5941df58cb5fbcd152ddd35bcebed8cc5 Mon Sep 17 00:00:00 2001 From: Cynthia J Date: Thu, 23 Jan 2025 22:41:01 -0800 Subject: [PATCH] Finish up all functionality --- .../example/lib/pages/imagen_page.dart | 20 ++--- .../firebase_vertexai/lib/src/base_model.dart | 5 +- .../lib/src/firebase_vertexai.dart | 4 + .../firebase_vertexai/lib/src/imagen_api.dart | 75 +++++++++++++------ .../lib/src/imagen_content.dart | 47 +++++++----- .../lib/src/imagen_model.dart | 59 +++++++-------- .../firebase_vertexai/lib/src/model.dart | 4 +- 7 files changed, 124 insertions(+), 90 deletions(-) diff --git a/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart b/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart index 6e2dbb497241..e321976dd288 100644 --- a/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart +++ b/packages/firebase_vertexai/firebase_vertexai/example/lib/pages/imagen_page.dart @@ -41,11 +41,15 @@ class _ImagenPageState extends State { @override void initState() { super.initState(); + var generationConfig = ImagenGenerationConfig( + negativePrompt: 'frog', + numberOfImages: 1, + aspectRatio: ImagenAspectRatio.square1x1, + imageFormat: ImagenFormat.jpeg(compressionQuality: 75), + ); _imagenModel = FirebaseVertexAI.instance.imageModel( modelName: 'imagen-3.0-generate-001', - generationConfig: ImagenGenerationConfig( - imageFormat: ImagenFormat.jpeg(compressionQuality: 75), - ), + generationConfig: generationConfig, safetySettings: ImagenSafetySettings( ImagenSafetyFilterLevel.blockLowAndAbove, ImagenPersonFilterLevel.allowAdult, @@ -133,15 +137,7 @@ class _ImagenPageState extends State { _loading = true; }); - var generationConfig = ImagenGenerationConfig( - negativePrompt: 'frog', - numberOfImages: 1, - aspectRatio: ImagenAspectRatio.square1x1); - - var response = await _imagenModel.generateImages( - prompt, - generationConfig: generationConfig, - ); + var response = await _imagenModel.generateImages(prompt); if (response.images.isNotEmpty) { var imagenImage = response.images[0]; diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/base_model.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/base_model.dart index 6c0a8bf51342..805356d1909b 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/base_model.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/base_model.dart @@ -42,8 +42,6 @@ abstract class BaseModel { required FirebaseApp app, required ApiClient client, }) : _model = normalizeModelName(model), - _app = app, - _location = location, _projectUri = _vertexUri(app, location), _client = client; @@ -51,8 +49,7 @@ abstract class BaseModel { static const _apiVersion = 'v1beta'; final ({String prefix, String name}) _model; - final FirebaseApp _app; - final String _location; + final Uri _projectUri; final ApiClient _client; diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/firebase_vertexai.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/firebase_vertexai.dart index 100036be6f90..19d37a947432 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/firebase_vertexai.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/firebase_vertexai.dart @@ -114,6 +114,10 @@ class FirebaseVertexAI extends FirebasePluginPlatform { ); } + /// Create a [ImagenModel]. + /// + /// The optional [safetySettings] can be used to control and guide the + /// generation. See [ImagenSafetySettings] for details. ImagenModel imageModel( {required String modelName, ImagenGenerationConfig? generationConfig, diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_api.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_api.dart index 846dc7cd07dc..735eab7217d9 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_api.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_api.dart @@ -12,18 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -/// +/// Specifies the level of safety filtering for image generation. +/// If not specified, default will be "block_medium_and_above". enum ImagenSafetyFilterLevel { - /// + /// Strongest filtering level, most strict blocking. blockLowAndAbove('block_low_and_above'), - /// + /// Block some problematic prompts and responses. blockMediumAndAbove('block_medium_and_above'), - /// + /// Reduces the number of requests blocked due to safety filters. + /// May increase objectionable content generated by Imagen. blockOnlyHigh('block_only_high'), - /// + /// Block very few problematic prompts and responses. + /// Access to this feature is restricted. blockNone('block_none'); const ImagenSafetyFilterLevel(this._jsonString); @@ -49,15 +52,16 @@ enum ImagenSafetyFilterLevel { String toString() => name; } -/// +/// Allow generation of people by the model. +/// If not specified, the default value is "allow_adult". enum ImagenPersonFilterLevel { - /// + /// Disallow the inclusion of people or faces in images. blockAll('dont_allow'), - /// + /// Allow generation of adults only. allowAdult('allow_adult'), - /// + /// Allow generation of people of all ages. allowAll('allow_all'); const ImagenPersonFilterLevel(this._jsonString); @@ -82,15 +86,17 @@ enum ImagenPersonFilterLevel { String toString() => name; } +/// A class representing safety settings for image generation. /// +/// It includes a safety filter level and a person filter level. final class ImagenSafetySettings { /// Constructor ImagenSafetySettings(this.safetyFilterLevel, this.personFilterLevel); - /// + /// The safety filter level final ImagenSafetyFilterLevel? safetyFilterLevel; - /// + /// The person filter level final ImagenPersonFilterLevel? personFilterLevel; /// Convert to json format. @@ -102,21 +108,21 @@ final class ImagenSafetySettings { }; } -/// +/// The aspect ratio for the image. The default value is "1:1". enum ImagenAspectRatio { - /// + /// Square (1:1). square1x1('1:1'), - /// + /// Portrait (9:16). portrait9x16('9:16'), - /// + /// Landscape (16:9). landscape16x9('16:9'), - /// + /// Portrait (3:4). portrait3x4('3:4'), - /// + /// Landscape (4:3). landscape4x3('4:3'); const ImagenAspectRatio(this._jsonString); @@ -143,17 +149,31 @@ enum ImagenAspectRatio { String toString() => name; } +/// Configuration options for image generation. final class ImagenGenerationConfig { + /// Constructor ImagenGenerationConfig( - {this.negativePrompt, - this.numberOfImages, + {this.numberOfImages, + this.negativePrompt, this.aspectRatio, this.imageFormat, this.addWatermark}); - final String? negativePrompt; + + /// The number of images to generate. Default is 1. final int? numberOfImages; + + /// A description of what to discourage in the generated images. + final String? negativePrompt; + + /// The aspect ratio for the image. The default value is "1:1". final ImagenAspectRatio? aspectRatio; + + /// The image format of the generated images. final ImagenFormat? imageFormat; + + /// Add an invisible watermark to the generated images. + /// Default value for each imagen model can be found in + /// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api#generate_images final bool? addWatermark; /// Convert to json format. @@ -166,15 +186,26 @@ final class ImagenGenerationConfig { }; } +/// Represents the image format and compression quality. final class ImagenFormat { + /// Constructor ImagenFormat(this.mimeType, this.compressionQuality); - ImagenFormat.png() : this("image/png", null); + /// Constructor for png + ImagenFormat.png() : this('image/png', null); + + /// Constructor for jpeg ImagenFormat.jpeg({int? compressionQuality}) - : this("image/jpeg", compressionQuality); + : this('image/jpeg', compressionQuality); + + /// The MIME type of the image format. The default value is "image/png". final String mimeType; + + /// The level of compression if the output type is "image/jpeg". + /// Accepted values are 0 through 100. The default value is 75. final int? compressionQuality; + /// Convert to json format. Map toJson() => { 'mimeType': mimeType, if (compressionQuality != null) diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_content.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_content.dart index 49bc019ffae7..f959d8863a20 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_content.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_content.dart @@ -15,28 +15,32 @@ import 'dart:convert'; import 'dart:typed_data'; import 'error.dart'; -/// +/// Base type of Imagen Image. sealed class ImagenImage { + /// Constructor + ImagenImage({required this.mimeType}); + + /// The MIME type of the image format. final String mimeType; /// Convert the [ImagenImage] content to json format. Object toJson(); - - ImagenImage({required this.mimeType}); } +/// Represents an image stored as a base64-encoded string. final class ImagenInlineImage implements ImagenImage { - /// Data contents in bytes. - final Uint8List bytesBase64Encoded; - - @override - final String mimeType; - + /// Constructor ImagenInlineImage({ required this.bytesBase64Encoded, required this.mimeType, }); + /// The data contents in bytes, encoded as base64. + final Uint8List bytesBase64Encoded; + + @override + final String mimeType; + @override Object toJson() => { 'mimeType': mimeType, @@ -44,17 +48,20 @@ final class ImagenInlineImage implements ImagenImage { }; } +/// Represents an image stored in Google Cloud Storage. final class ImagenGCSImage implements ImagenImage { - @override - final String mimeType; - - final String gcsUri; - + /// Constructor ImagenGCSImage({ required this.gcsUri, required this.mimeType, }); + /// The storage URI of the image. + final String gcsUri; + + @override + final String mimeType; + @override Object toJson() => { 'mimeType': mimeType, @@ -62,15 +69,15 @@ final class ImagenGCSImage implements ImagenImage { }; } +/// Represents the response from an image generation request. final class ImagenGenerationResponse { + /// Constructor ImagenGenerationResponse({ required this.images, this.filteredReason, }); - final List images; - final String? filteredReason; - + /// Factory method to create an ImagenGenerationResponse from a JSON object. factory ImagenGenerationResponse.fromJson(Map json) { final filteredReason = json['filteredReason'] as String?; final imagesJson = json['predictions'] as List; @@ -102,6 +109,12 @@ final class ImagenGenerationResponse { throw ArgumentError('Unsupported ImagenImage type: $T'); } } + + /// A list of generated images. The type of the images depends on the T parameter. + final List images; + + /// If the generation was filtered due to safety reasons, a message explaining the reason. + final String? filteredReason; } /// Parse the json to [ImagenGenerationResponse] diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_model.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_model.dart index 2eca0c77c8e7..9d65fff160da 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_model.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/imagen_model.dart @@ -11,15 +11,15 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -import 'imagen_api.dart'; -import 'imagen_content.dart'; -import 'base_model.dart'; -import 'client.dart'; - import 'package:firebase_app_check/firebase_app_check.dart'; import 'package:firebase_auth/firebase_auth.dart'; import 'package:firebase_core/firebase_core.dart'; +import 'base_model.dart'; +import 'client.dart'; +import 'imagen_api.dart'; +import 'imagen_content.dart'; + /// final class ImagenModel extends BaseModel { ImagenModel._( @@ -45,31 +45,26 @@ final class ImagenModel extends BaseModel { Map _generateImagenRequest( String prompt, { - ImagenGenerationConfig? generationConfig, String? gcsUri, }) { final parameters = {}; if (gcsUri != null) parameters['storageUri'] = gcsUri; - // Merge generation configurations - final mergedConfig = { - ...(_generationConfig?.toJson() ?? {}), - ...(generationConfig?.toJson() ?? {}), - }; - - parameters['sampleCount'] = mergedConfig['numberOfImages'] ?? 1; - if (mergedConfig['aspectRatio'] != null) { - parameters['aspectRatio'] = mergedConfig['aspectRatio']; - } - if (mergedConfig['negativePrompt'] != null) { - parameters['negativePrompt'] = mergedConfig['negativePrompt']; - } - if (mergedConfig['addWatermark'] != null) { - parameters['addWatermark'] = mergedConfig['addWatermark']; - } - if (mergedConfig['outputOption'] != null) { - parameters['outputOption'] = mergedConfig['outputOption']; + parameters['sampleCount'] = _generationConfig?.numberOfImages ?? 1; + if (_generationConfig != null) { + if (_generationConfig.aspectRatio != null) { + parameters['aspectRatio'] = _generationConfig.aspectRatio; + } + if (_generationConfig.negativePrompt != null) { + parameters['negativePrompt'] = _generationConfig.negativePrompt; + } + if (_generationConfig.addWatermark != null) { + parameters['addWatermark'] = _generationConfig.addWatermark; + } + if (_generationConfig.imageFormat != null) { + parameters['outputOption'] = _generationConfig.imageFormat!.toJson(); + } } if (_safetySettings != null) { @@ -91,30 +86,30 @@ final class ImagenModel extends BaseModel { }; } + /// Generates images with format of [ImagenInlineImage] based on the given + /// prompt. Future> generateImages( - String prompt, { - ImagenGenerationConfig? generationConfig, - }) => + String prompt, + ) => makeRequest( Task.predict, _generateImagenRequest( prompt, - generationConfig: generationConfig, ), (jsonObject) => parseImagenGenerationResponse(jsonObject), ); + /// Generates images with format of [ImagenGCSImage] based on the given + /// prompt. Future> generateImagesGCS( String prompt, - String gcsUri, { - ImagenGenerationConfig? generationConfig, - }) => + String gcsUri, + ) => makeRequest( Task.predict, _generateImagenRequest( prompt, - generationConfig: generationConfig, gcsUri: gcsUri, ), (jsonObject) => diff --git a/packages/firebase_vertexai/firebase_vertexai/lib/src/model.dart b/packages/firebase_vertexai/firebase_vertexai/lib/src/model.dart index 91a050cc2922..ddf0d1e5a37e 100644 --- a/packages/firebase_vertexai/firebase_vertexai/lib/src/model.dart +++ b/packages/firebase_vertexai/firebase_vertexai/lib/src/model.dart @@ -19,15 +19,13 @@ import 'dart:async'; import 'package:firebase_app_check/firebase_app_check.dart'; import 'package:firebase_auth/firebase_auth.dart'; import 'package:firebase_core/firebase_core.dart'; - import 'package:http/http.dart' as http; import 'api.dart'; +import 'base_model.dart'; import 'client.dart'; import 'content.dart'; import 'function_calling.dart'; -import 'vertex_version.dart'; -import 'base_model.dart'; /// A multimodel generative model (like Gemini). ///