diff --git a/packages/firebase_ai/firebase_ai/example/lib/main.dart b/packages/firebase_ai/firebase_ai/example/lib/main.dart index db1344210deb..ed9748965723 100644 --- a/packages/firebase_ai/firebase_ai/example/lib/main.dart +++ b/packages/firebase_ai/firebase_ai/example/lib/main.dart @@ -31,6 +31,7 @@ import 'pages/json_schema_page.dart'; import 'pages/schema_page.dart'; import 'pages/token_count_page.dart'; import 'pages/video_page.dart'; +import 'pages/server_template_page.dart'; void main() async { WidgetsFlutterBinding.ensureInitialized(); @@ -199,6 +200,11 @@ class _HomeScreenState extends State { model: currentModel, useVertexBackend: useVertexBackend, ); + case 11: + return ServerTemplatePage( + title: 'Server Template', + useVertexBackend: useVertexBackend, + ); default: // Fallback to the first page in case of an unexpected index @@ -227,18 +233,15 @@ class _HomeScreenState extends State { style: TextStyle( fontSize: 12, color: widget.useVertexBackend - ? Theme.of(context) - .colorScheme - .onSurface - .withValues(alpha: 0.7) + ? Theme.of(context).colorScheme.onSurface.withAlpha(180) : Theme.of(context).colorScheme.primary, ), ), Switch( value: widget.useVertexBackend, onChanged: widget.onBackendChanged, - activeTrackColor: Colors.green.withValues(alpha: 0.5), - inactiveTrackColor: Colors.blueGrey.withValues(alpha: 0.5), + activeTrackColor: Colors.green.withAlpha(128), + inactiveTrackColor: Colors.blueGrey.withAlpha(128), activeThumbColor: Colors.green, inactiveThumbColor: Colors.blueGrey, ), @@ -251,7 +254,7 @@ class _HomeScreenState extends State { : Theme.of(context) .colorScheme .onSurface - .withValues(alpha: 0.7), + .withAlpha(180), ), ), ], @@ -273,7 +276,7 @@ class _HomeScreenState extends State { unselectedFontSize: 9, selectedItemColor: Theme.of(context).colorScheme.primary, unselectedItemColor: widget.useVertexBackend - ? Theme.of(context).colorScheme.onSurface.withValues(alpha: 0.7) + ? 
Theme.of(context).colorScheme.onSurface.withAlpha(180) : Colors.grey, items: const [ BottomNavigationBarItem( @@ -333,6 +336,13 @@ class _HomeScreenState extends State { label: 'Live', tooltip: 'Live Stream', ), + BottomNavigationBarItem( + icon: Icon( + Icons.storage, + ), + label: 'Server', + tooltip: 'Server Template', + ), ], currentIndex: widget.selectedIndex, onTap: _onItemTapped, diff --git a/packages/firebase_ai/firebase_ai/example/lib/pages/image_prompt_page.dart b/packages/firebase_ai/firebase_ai/example/lib/pages/image_prompt_page.dart index 48fc8667af59..5c5009ca3158 100644 --- a/packages/firebase_ai/firebase_ai/example/lib/pages/image_prompt_page.dart +++ b/packages/firebase_ai/firebase_ai/example/lib/pages/image_prompt_page.dart @@ -11,7 +11,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - import 'package:flutter/material.dart'; import 'package:firebase_ai/firebase_ai.dart'; import 'package:flutter/services.dart'; @@ -65,11 +64,13 @@ class _ImagePromptPageState extends State { var content = _generatedContent[idx]; return MessageWidget( text: content.text, - image: Image.memory( - content.imageBytes!, - cacheWidth: 400, - cacheHeight: 400, - ), + image: content.imageBytes == null + ? null + : Image.memory( + content.imageBytes!, + cacheWidth: 400, + cacheHeight: 400, + ), isFromUser: content.fromUser ?? false, ); }, diff --git a/packages/firebase_ai/firebase_ai/example/lib/pages/server_template_page.dart b/packages/firebase_ai/firebase_ai/example/lib/pages/server_template_page.dart new file mode 100644 index 000000000000..44a226f9a2db --- /dev/null +++ b/packages/firebase_ai/firebase_ai/example/lib/pages/server_template_page.dart @@ -0,0 +1,438 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +import 'dart:convert'; +import 'package:flutter/material.dart'; +import 'package:flutter/services.dart'; +import '../widgets/message_widget.dart'; +import 'package:firebase_ai/firebase_ai.dart'; + +class ServerTemplatePage extends StatefulWidget { + const ServerTemplatePage({ + super.key, + required this.title, + required this.useVertexBackend, + }); + + final String title; + final bool useVertexBackend; + + @override + State createState() => _ServerTemplatePageState(); +} + +class _ServerTemplatePageState extends State { + final ScrollController _scrollController = ScrollController(); + final TextEditingController _textController = TextEditingController(); + final FocusNode _textFieldFocus = FocusNode(); + final List _messages = []; + bool _loading = false; + + TemplateGenerativeModel? _templateGenerativeModel; + TemplateChatSession? _chatSession; + TemplateImagenModel? _templateImagenModel; + TemplateChatSession? 
_chatFunctionSession; + + @override + void initState() { + super.initState(); + _initializeServerTemplate(); + } + + void _initializeServerTemplate() { + if (widget.useVertexBackend) { + _templateGenerativeModel = + FirebaseAI.vertexAI(location: 'global').templateGenerativeModel(); + _templateImagenModel = + FirebaseAI.vertexAI(location: 'global').templateImagenModel(); + } else { + _templateGenerativeModel = + FirebaseAI.googleAI().templateGenerativeModel(); + _templateImagenModel = FirebaseAI.googleAI().templateImagenModel(); + } + _chatSession = _templateGenerativeModel?.startChat('chat_history.prompt'); + _chatFunctionSession = + _templateGenerativeModel?.startChat('function-calling'); + } + + void _scrollDown() { + WidgetsBinding.instance.addPostFrameCallback( + (_) => _scrollController.animateTo( + _scrollController.position.maxScrollExtent, + duration: const Duration( + milliseconds: 750, + ), + curve: Curves.easeOutCirc, + ), + ); + } + + @override + Widget build(BuildContext context) { + return Scaffold( + appBar: AppBar( + title: Text(widget.title), + ), + body: Padding( + padding: const EdgeInsets.all(8), + child: Column( + mainAxisAlignment: MainAxisAlignment.center, + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Expanded( + child: ListView.builder( + controller: _scrollController, + itemBuilder: (context, idx) { + final message = _messages[idx]; + return MessageWidget( + text: message.text, + image: message.imageBytes != null + ? Image.memory( + message.imageBytes!, + cacheWidth: 400, + cacheHeight: 400, + ) + : null, + isFromUser: message.fromUser ?? 
false, + ); + }, + itemCount: _messages.length, + ), + ), + Padding( + padding: const EdgeInsets.symmetric( + vertical: 25, + horizontal: 15, + ), + child: Row( + children: [ + Expanded( + child: TextField( + autofocus: true, + focusNode: _textFieldFocus, + controller: _textController, + onSubmitted: _sendServerTemplateMessage, + ), + ), + const SizedBox.square( + dimension: 15, + ), + if (!_loading) + IconButton( + onPressed: () async { + await _serverTemplateImagen(_textController.text); + }, + icon: Icon( + Icons.image_search, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Imagen', + ), + if (!_loading) + IconButton( + onPressed: () async { + await _serverTemplateFunctionCall(_textController.text); + }, + icon: Icon( + Icons.functions, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Function Calling', + ), + if (!_loading) + IconButton( + onPressed: () async { + await _serverTemplateImageInput(_textController.text); + }, + icon: Icon( + Icons.image, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Image Input', + ), + if (!_loading) + IconButton( + onPressed: () async { + await _serverTemplateChat(_textController.text); + }, + icon: Icon( + Icons.chat, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Chat', + ), + if (!_loading) + IconButton( + onPressed: () async { + await _sendServerTemplateMessage(_textController.text); + }, + icon: Icon( + Icons.send, + color: Theme.of(context).colorScheme.primary, + ), + tooltip: 'Generate', + ) + else + const CircularProgressIndicator(), + ], + ), + ), + ], + ), + ), + ); + } + + Future _serverTemplateFunctionCall(String message) async { + setState(() { + _loading = true; + }); + + try { + _messages.add( + MessageData(text: message, fromUser: true), + ); + var response = await _chatFunctionSession?.sendMessage( + Content.text(message), + inputs: { + 'customerName': message, + 'orientation': 'PORTRAIT', + 'useFlash': true, + 'zoom': 2, + }, + ); + + 
_messages.add(MessageData(text: response?.text, fromUser: false)); + + final functionCalls = response?.functionCalls.toList(); + if (functionCalls!.isNotEmpty) { + final functionCall = functionCalls.first; + if (functionCall.name == 'takePicture') { + ByteData catBytes = await rootBundle.load('assets/images/cat.jpg'); + var imageBytes = catBytes.buffer.asUint8List(); + final functionResult = { + 'aspectRatio': '16:9', + 'mimeType': 'image/jpeg', + 'data': base64Encode(imageBytes), + }; + var functionResponse = await _chatFunctionSession?.sendMessage( + Content.functionResponse(functionCall.name, functionResult), + inputs: {}, + ); + _messages + .add(MessageData(text: functionResponse?.text, fromUser: false)); + } + } + + setState(() { + _loading = false; + _scrollDown(); + }); + } catch (e) { + _showError(e.toString()); + setState(() { + _loading = false; + }); + } finally { + _textController.clear(); + setState(() { + _loading = false; + }); + _textFieldFocus.requestFocus(); + } + } + + Future _serverTemplateImagen(String message) async { + setState(() { + _loading = true; + }); + MessageData? 
resultMessage; + try { + _messages.add(MessageData(text: message, fromUser: true)); + var response = await _templateImagenModel?.generateImages( + 'portrait-googleai', + inputs: { + 'animal': message, + }, + ); + + if (response!.images.isNotEmpty) { + var imagenImage = response.images[0]; + + resultMessage = MessageData( + imageBytes: imagenImage.bytesBase64Encoded, + text: message, + fromUser: false, + ); + } else { + // Handle the case where no images were generated + _showError('Error: No images were generated.'); + } + + setState(() { + if (resultMessage != null) { + _messages.add(resultMessage); + } + _loading = false; + _scrollDown(); + }); + } catch (e) { + _showError(e.toString()); + setState(() { + _loading = false; + }); + } finally { + _textController.clear(); + setState(() { + _loading = false; + }); + _textFieldFocus.requestFocus(); + } + } + + Future _serverTemplateImageInput(String message) async { + setState(() { + _loading = true; + }); + + try { + ByteData catBytes = await rootBundle.load('assets/images/cat.jpg'); + var imageBytes = catBytes.buffer.asUint8List(); + _messages.add( + MessageData( + text: message, + imageBytes: imageBytes, + fromUser: true, + ), + ); + + var response = await _templateGenerativeModel?.generateContent( + 'media.prompt', + inputs: { + 'imageData': { + 'isInline': true, + 'mimeType': 'image/jpeg', + 'contents': base64Encode(imageBytes), + }, + }, + ); + _messages.add(MessageData(text: response?.text, fromUser: false)); + + setState(() { + _loading = false; + _scrollDown(); + }); + } catch (e) { + _showError(e.toString()); + setState(() { + _loading = false; + }); + } finally { + _textController.clear(); + setState(() { + _loading = false; + }); + _textFieldFocus.requestFocus(); + } + } + + Future _serverTemplateChat(String message) async { + setState(() { + _loading = true; + }); + + try { + _messages.add( + MessageData(text: message, fromUser: true), + ); + var response = await _chatSession?.sendMessage( + 
Content.text(message), + inputs: { + 'message': message, + }, + ); + + var text = response?.text; + + _messages.add(MessageData(text: text, fromUser: false)); + + setState(() { + _loading = false; + _scrollDown(); + }); + } catch (e) { + _showError(e.toString()); + setState(() { + _loading = false; + }); + } finally { + _textController.clear(); + setState(() { + _loading = false; + }); + _textFieldFocus.requestFocus(); + } + } + + Future _sendServerTemplateMessage(String message) async { + setState(() { + _loading = true; + }); + + try { + var response = await _templateGenerativeModel?.generateContent( + 'new-greeting', + ); + + _messages.add(MessageData(text: response?.text, fromUser: false)); + + setState(() { + _loading = false; + _scrollDown(); + }); + } catch (e) { + _showError(e.toString()); + setState(() { + _loading = false; + }); + } finally { + _textController.clear(); + setState(() { + _loading = false; + }); + _textFieldFocus.requestFocus(); + } + } + + void _showError(String message) { + showDialog( + context: context, + builder: (context) { + return AlertDialog( + title: const Text('Something went wrong'), + content: SingleChildScrollView( + child: SelectableText(message), + ), + actions: [ + TextButton( + onPressed: () { + Navigator.of(context).pop(); + }, + child: const Text('OK'), + ), + ], + ); + }, + ); + } +} diff --git a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart index b42c458d3075..629b642fe740 100644 --- a/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart +++ b/packages/firebase_ai/firebase_ai/lib/firebase_ai.dart @@ -33,7 +33,12 @@ export 'src/api.dart' SafetySetting, UsageMetadata; export 'src/base_model.dart' - show GenerativeModel, ImagenModel, LiveGenerativeModel; + show + GenerativeModel, + ImagenModel, + LiveGenerativeModel, + TemplateGenerativeModel, + TemplateImagenModel; export 'src/chat.dart' show ChatSession, StartChatExtension; export 'src/content.dart' show 
@@ -100,6 +105,9 @@ export 'src/live_api.dart' LiveServerResponse; export 'src/live_session.dart' show LiveSession; export 'src/schema.dart' show Schema, SchemaType; +export 'src/server_template/template_chat.dart' + show TemplateChatSession, StartTemplateChatExtension; + export 'src/tool.dart' show FunctionCallingConfig, diff --git a/packages/firebase_ai/firebase_ai/lib/src/base_model.dart b/packages/firebase_ai/firebase_ai/lib/src/base_model.dart index 5cb2f83d5ed0..1ee1cc48864c 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/base_model.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/base_model.dart @@ -41,6 +41,8 @@ import 'tool.dart'; part 'generative_model.dart'; part 'imagen/imagen_model.dart'; part 'live_model.dart'; +part 'server_template/template_generative_model.dart'; +part 'server_template/template_imagen_model.dart'; /// [Task] enum class for [GenerativeModel] to make request. enum Task { @@ -57,6 +59,18 @@ enum Task { predict, } +/// [TemplateTask] enum class for [TemplateGenerativeModel] to make request. +enum TemplateTask { + /// Request type for server template generate content. + templateGenerateContent, + + /// Request type for server template stream generate content + templateStreamGenerateContent, + + /// Request type for server template for Prediction Services like Imagen. 
+ templatePredict, +} + abstract interface class _ModelUri { String get baseAuthority; String get apiVersion; @@ -94,6 +108,7 @@ final class _VertexUri implements _ModelUri { } final Uri _projectUri; + @override final ({String prefix, String name}) model; @@ -130,10 +145,12 @@ final class _GoogleAIUri implements _ModelUri { static const _apiVersion = 'v1beta'; static const _baseAuthority = 'firebasevertexai.googleapis.com'; + static Uri _googleAIBaseUri( {String apiVersion = _apiVersion, required FirebaseApp app}) => Uri.https( _baseAuthority, '$apiVersion/projects/${app.options.projectId}'); + final Uri _baseUri; @override @@ -151,6 +168,92 @@ final class _GoogleAIUri implements _ModelUri { .followedBy([model.prefix, '${model.name}:${task.name}'])); } +abstract interface class _TemplateUri { + String get baseAuthority; + String get apiVersion; + Uri templateTaskUri(TemplateTask task, String templateId); + String templateName(String templateId); +} + +final class _TemplateVertexUri implements _TemplateUri { + _TemplateVertexUri({required String location, required FirebaseApp app}) + : _templateUri = _vertexTemplateUri(app, location), + _templateName = _vertexTemplateName(app, location); + + static const _baseAuthority = 'firebasevertexai.googleapis.com'; + static const _apiVersion = 'v1beta'; + + final Uri _templateUri; + final String _templateName; + + static Uri _vertexTemplateUri(FirebaseApp app, String location) { + var projectId = app.options.projectId; + return Uri.https( + _baseAuthority, + '/$_apiVersion/projects/$projectId/locations/$location', + ); + } + + static String _vertexTemplateName(FirebaseApp app, String location) { + var projectId = app.options.projectId; + return 'projects/$projectId/locations/$location'; + } + + @override + String get baseAuthority => _baseAuthority; + + @override + String get apiVersion => _apiVersion; + + @override + Uri templateTaskUri(TemplateTask task, String templateId) { + return _templateUri.replace( + pathSegments: 
_templateUri.pathSegments + .followedBy(['templates', '$templateId:${task.name}'])); + } + + @override + String templateName(String templateId) => + '$_templateName/templates/$templateId'; +} + +final class _TemplateGoogleAIUri implements _TemplateUri { + _TemplateGoogleAIUri({ + required FirebaseApp app, + }) : _templateUri = _googleAITemplateUri(app: app), + _templateName = _googleAITemplateName(app: app); + + static const _baseAuthority = 'firebasevertexai.googleapis.com'; + static const _apiVersion = 'v1beta'; + final Uri _templateUri; + final String _templateName; + + static Uri _googleAITemplateUri( + {String apiVersion = _apiVersion, required FirebaseApp app}) => + Uri.https( + _baseAuthority, '$apiVersion/projects/${app.options.projectId}'); + + static String _googleAITemplateName({required FirebaseApp app}) => + 'projects/${app.options.projectId}'; + + @override + String get baseAuthority => _baseAuthority; + + @override + String get apiVersion => _apiVersion; + + @override + Uri templateTaskUri(TemplateTask task, String templateId) { + return _templateUri.replace( + pathSegments: _templateUri.pathSegments + .followedBy(['templates', '$templateId:${task.name}'])); + } + + @override + String templateName(String templateId) => + '$_templateName/templates/$templateId'; +} + /// Base class for models. /// /// Do not instantiate directly. @@ -231,3 +334,60 @@ abstract class BaseApiClientModel extends BaseModel { T Function(Map) parse) => _client.makeRequest(taskUri(task), params).then(parse); } + +abstract class BaseTemplateApiClientModel extends BaseApiClientModel { + BaseTemplateApiClientModel( + {required super.serializationStrategy, + required super.modelUri, + required super.client, + required _TemplateUri templateUri}) + : _templateUri = templateUri; + + final _TemplateUri _templateUri; + + /// Make a unary request for [task] with [templateId] and JSON encodable + /// [inputs]. + Future makeTemplateRequest( + TemplateTask task, + String templateId, + Map? 
inputs, + Iterable? history, + T Function(Map) parse) { + Map body = {}; + if (inputs != null) { + body['inputs'] = inputs; + } + if (history != null) { + body['history'] = history.map((c) => c.toJson()).toList(); + } + return _client + .makeRequest(templateTaskUri(task, templateId), body) + .then(parse); + } + + /// Make a unary request for [task] with [templateId] and JSON encodable + /// [inputs]. + Stream streamTemplateRequest( + TemplateTask task, + String templateId, + Map? inputs, + Iterable? history, + T Function(Map) parse) { + Map body = {}; + if (inputs != null) { + body['inputs'] = inputs; + } + if (history != null) { + body['history'] = history.map((c) => c.toJson()).toList(); + } + final response = + _client.streamRequest(templateTaskUri(task, templateId), body); + return response.map(parse); + } + + Uri templateTaskUri(TemplateTask task, String templateId) => + _templateUri.templateTaskUri(task, templateId); + + String templateName(String templateId) => + _templateUri.templateName(templateId); +} diff --git a/packages/firebase_ai/firebase_ai/lib/src/chat.dart b/packages/firebase_ai/firebase_ai/lib/src/chat.dart index fdcbd3cb2920..6f6d32d6f6f0 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/chat.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/chat.dart @@ -17,6 +17,7 @@ import 'dart:async'; import 'api.dart'; import 'base_model.dart'; import 'content.dart'; +import 'utils/chat_utils.dart'; import 'utils/mutex.dart'; /// A back-and-forth chat with a generative model. @@ -24,8 +25,7 @@ import 'utils/mutex.dart'; /// Records messages sent and received in [history]. The history will always /// record the content from the first candidate in the /// [GenerateContentResponse], other candidates may be available on the returned -/// response. The history is maintained and updated by the `google_generative_ai` -/// package and reflects the most current state of the chat session. +/// response. 
The history reflects the most current state of the chat session. final class ChatSession { ChatSession._(this._generateContent, this._generateContentStream, this._history, this._safetySettings, this._generationConfig); @@ -114,7 +114,7 @@ final class ChatSession { } if (content.isNotEmpty) { _history.add(message); - _history.add(_aggregate(content)); + _history.add(historyAggregate(content)); } } catch (e, s) { controller.addError(e, s); @@ -124,46 +124,6 @@ final class ChatSession { }); return controller.stream; } - - /// Aggregates a list of [Content] responses into a single [Content]. - /// - /// Includes all the [Content.parts] of every element of [contents], - /// and concatenates adjacent [TextPart]s into a single [TextPart], - /// even across adjacent [Content]s. - Content _aggregate(List contents) { - assert(contents.isNotEmpty); - final role = contents.first.role ?? 'model'; - final textBuffer = StringBuffer(); - // If non-null, only a single text part has been seen. - TextPart? previousText; - final parts = []; - void addBufferedText() { - if (textBuffer.isEmpty) return; - if (previousText case final singleText?) { - parts.add(singleText); - previousText = null; - } else { - parts.add(TextPart(textBuffer.toString())); - } - textBuffer.clear(); - } - - for (final content in contents) { - for (final part in content.parts) { - if (part case TextPart(:final text)) { - if (text.isNotEmpty) { - previousText = textBuffer.isEmpty ? 
part : null; - textBuffer.write(text); - } - } else { - addBufferedText(); - parts.add(part); - } - } - } - addBufferedText(); - return Content(role, parts); - } } /// [StartChatExtension] on [GenerativeModel] diff --git a/packages/firebase_ai/firebase_ai/lib/src/client.dart b/packages/firebase_ai/firebase_ai/lib/src/client.dart index 5befcf695f22..df77752e535c 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/client.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/client.dart @@ -64,6 +64,7 @@ final class HttpApiClient implements ApiClient { Future> makeRequest( Uri uri, Map body) async { final headers = await _headers(); + print('uri: $uri \nbody: $body \nheaders: $headers'); final response = await (_httpClient?.post ?? http.post)( uri, headers: headers, diff --git a/packages/firebase_ai/firebase_ai/lib/src/firebase_ai.dart b/packages/firebase_ai/firebase_ai/lib/src/firebase_ai.dart index c0379662fdf8..8195aea0b91c 100644 --- a/packages/firebase_ai/firebase_ai/lib/src/firebase_ai.dart +++ b/packages/firebase_ai/firebase_ai/lib/src/firebase_ai.dart @@ -17,6 +17,7 @@ import 'package:firebase_auth/firebase_auth.dart'; import 'package:firebase_core/firebase_core.dart'; import 'package:firebase_core_platform_interface/firebase_core_platform_interface.dart' show FirebasePluginPlatform; +import 'package:meta/meta.dart'; import '../firebase_ai.dart'; import 'base_model.dart'; @@ -196,4 +197,27 @@ class FirebaseAI extends FirebasePluginPlatform { useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens, ); } + + @experimental + TemplateGenerativeModel templateGenerativeModel() { + return createTemplateGenerativeModel( + app: app, + location: location, + useVertexBackend: _useVertexBackend, + useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens, + auth: auth, + appCheck: appCheck); + } + + @experimental + TemplateImagenModel templateImagenModel() { + return createTemplateImagenModel( + app: app, + location: location, + useVertexBackend: _useVertexBackend, + 
useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens, + auth: auth, + appCheck: appCheck, + ); + } } diff --git a/packages/firebase_ai/firebase_ai/lib/src/server_template/template_chat.dart b/packages/firebase_ai/firebase_ai/lib/src/server_template/template_chat.dart new file mode 100644 index 000000000000..74e7221a339d --- /dev/null +++ b/packages/firebase_ai/firebase_ai/lib/src/server_template/template_chat.dart @@ -0,0 +1,144 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +import 'dart:async'; + +import '../api.dart'; +import '../base_model.dart'; +import '../content.dart'; +import '../utils/chat_utils.dart'; +import '../utils/mutex.dart'; + +/// A back-and-forth chat with a server template. +/// +/// Records messages sent and received in [history]. The history will always +/// record the content from the first candidate in the +/// [GenerateContentResponse], other candidates may be available on the returned +/// response. The history reflects the most current state of the chat session. +final class TemplateChatSession { + TemplateChatSession._( + this._templateHistoryGenerateContent, + this._templateHistoryGenerateContentStream, + this._templateId, + this._history, + ); + + final Future Function( + Iterable content, String templateId, + {Map? inputs}) _templateHistoryGenerateContent; + + final Stream Function( + Iterable content, String templateId, + {Map? 
inputs}) _templateHistoryGenerateContentStream; + final String _templateId; + final List _history; + + final _mutex = Mutex(); + + /// The content that has been successfully sent to, or received from, the + /// generative model. + /// + /// If there are outstanding requests from calls to [sendMessage], + /// these will not be reflected in the history. + /// Messages without a candidate in the response are not recorded in history, + /// including the message sent to the model. + Iterable get history => _history.skip(0); + + /// Sends [inputs] to the server template as a continuation of the chat [history]. + /// + /// Prepends the history to the request and uses the provided model to + /// generate new content. + /// + /// When there are no candidates in the response, the [message] and response + /// are ignored and will not be recorded in the [history]. + Future sendMessage(Content message, + {Map? inputs}) async { + final lock = await _mutex.acquire(); + try { + final response = await _templateHistoryGenerateContent( + _history.followedBy([message]), + _templateId, + inputs: inputs, + ); + if (response.candidates case [final candidate, ...]) { + _history.add(message); + final normalizedContent = candidate.content.role == null + ? Content('model', candidate.content.parts) + : candidate.content; + _history.add(normalizedContent); + } + return response; + } finally { + lock.release(); + } + } + + /// Sends [message] to the server template as a continuation of the chat + /// [history]. + /// + /// Returns a stream of responses, which may be chunks of a single aggregate + /// response. + /// + /// Prepends the history to the request and uses the provided model to + /// generate new content. + /// + /// When there are no candidates in the response, the [message] and response + /// are ignored and will not be recorded in the [history]. + Stream sendMessageStream(Content message, + {Map? 
inputs}) { + final controller = StreamController(sync: true); + _mutex.acquire().then((lock) async { + try { + final responses = _templateHistoryGenerateContentStream( + _history.followedBy([message]), + _templateId, + inputs: inputs, + ); + final content = []; + await for (final response in responses) { + if (response.candidates case [final candidate, ...]) { + content.add(candidate.content); + } + controller.add(response); + } + if (content.isNotEmpty) { + _history.add(message); + _history.add(historyAggregate(content)); + } + } catch (e, s) { + controller.addError(e, s); + } + lock.release(); + unawaited(controller.close()); + }); + return controller.stream; + } +} + +/// An extension on [TemplateGenerativeModel] that provides a `startChat` method. +extension StartTemplateChatExtension on TemplateGenerativeModel { + /// Starts a [TemplateChatSession] that will use this model to respond to messages. + /// + /// ```dart + /// final chat = model.startChat(); + /// final response = await chat.sendMessage(Content.text('Hello there.')); + /// print(response.text); + /// ``` + TemplateChatSession startChat(String templateId, {List? history}) => + TemplateChatSession._( + templateGenerateContentWithHistory, + templateGenerateContentWithHistoryStream, + templateId, + history ?? [], + ); +} diff --git a/packages/firebase_ai/firebase_ai/lib/src/server_template/template_generative_model.dart b/packages/firebase_ai/firebase_ai/lib/src/server_template/template_generative_model.dart new file mode 100644 index 000000000000..d1b46213cffb --- /dev/null +++ b/packages/firebase_ai/firebase_ai/lib/src/server_template/template_generative_model.dart @@ -0,0 +1,137 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// ignore_for_file: use_late_for_private_fields_and_variables
part of '../base_model.dart';

/// A generative model that connects to a remote server template.
@experimental
final class TemplateGenerativeModel extends BaseTemplateApiClientModel {
  /// Internal constructor, exposed for testing.
  ///
  /// Not part of the public API surface; use [createTemplateGenerativeModel].
  @internal
  TemplateGenerativeModel.internal({
    required String location,
    required FirebaseApp app,
    required bool useVertexBackend,
    bool? useLimitedUseAppCheckTokens,
    FirebaseAppCheck? appCheck,
    FirebaseAuth? auth,
    http.Client? httpClient,
  }) : super(
          // The developer (Google AI) and Vertex backends use different wire
          // formats, so the serialization strategy follows the backend choice.
          serializationStrategy: useVertexBackend
              ? VertexSerialization()
              : DeveloperSerialization(),
          // Server templates are not bound to a concrete model, so the model
          // segment of the URI is intentionally empty.
          modelUri: useVertexBackend
              ? _VertexUri(app: app, model: '', location: location)
              : _GoogleAIUri(app: app, model: ''),
          client: HttpApiClient(
              apiKey: app.options.apiKey,
              httpClient: httpClient,
              requestHeaders: BaseModel.firebaseTokens(
                  appCheck, auth, app, useLimitedUseAppCheckTokens)),
          templateUri: useVertexBackend
              ? _TemplateVertexUri(app: app, location: location)
              : _TemplateGoogleAIUri(app: app),
        );

  /// Private constructor.
  ///
  /// Redirects to [TemplateGenerativeModel.internal], which takes exactly the
  /// same parameters, to avoid duplicating the initializer list.
  TemplateGenerativeModel._({
    required String location,
    required FirebaseApp app,
    required bool useVertexBackend,
    bool? useLimitedUseAppCheckTokens,
    FirebaseAppCheck? appCheck,
    FirebaseAuth? auth,
    http.Client? httpClient,
  }) : this.internal(
          location: location,
          app: app,
          useVertexBackend: useVertexBackend,
          useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens,
          appCheck: appCheck,
          auth: auth,
          httpClient: httpClient,
        );

  /// Generates content from a template with the given [templateId] and
  /// [inputs].
  ///
  /// Sends a "templateGenerateContent" API request for the configured model.
  @experimental
  Future<GenerateContentResponse> generateContent(String templateId,
          {Map<String, Object?>? inputs}) =>
      makeTemplateRequest(TemplateTask.templateGenerateContent, templateId,
          inputs, null, _serializationStrategy.parseGenerateContentResponse);

  /// Generates a stream of content responding to [templateId] and [inputs].
  ///
  /// Sends a "templateStreamGenerateContent" API request for the server
  /// template, and waits for the response.
  @experimental
  Stream<GenerateContentResponse> generateContentStream(String templateId,
      {Map<String, Object?>? inputs}) {
    return streamTemplateRequest(
        TemplateTask.templateStreamGenerateContent,
        templateId,
        inputs,
        null,
        _serializationStrategy.parseGenerateContentResponse);
  }

  /// Generates content from a template with the given [templateId], [inputs]
  /// and [history].
  @experimental
  Future<GenerateContentResponse> templateGenerateContentWithHistory(
          Iterable<Content> history, String templateId,
          {Map<String, Object?>? inputs}) =>
      makeTemplateRequest(TemplateTask.templateGenerateContent, templateId,
          inputs, history, _serializationStrategy.parseGenerateContentResponse);

  /// Generates a stream of content from a template with the given
  /// [templateId], [inputs] and [history].
  @experimental
  Stream<GenerateContentResponse> templateGenerateContentWithHistoryStream(
      Iterable<Content> history, String templateId,
      {Map<String, Object?>? inputs}) {
    return streamTemplateRequest(
        TemplateTask.templateStreamGenerateContent,
        templateId,
        inputs,
        history,
        _serializationStrategy.parseGenerateContentResponse);
  }
}

/// Returns a [TemplateGenerativeModel] using its private constructor.
@experimental
TemplateGenerativeModel createTemplateGenerativeModel({
  required FirebaseApp app,
  required String location,
  required bool useVertexBackend,
  bool? useLimitedUseAppCheckTokens,
  FirebaseAppCheck? appCheck,
  FirebaseAuth? auth,
}) =>
    TemplateGenerativeModel._(
      app: app,
      appCheck: appCheck,
      useVertexBackend: useVertexBackend,
      useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens,
      auth: auth,
      location: location,
    );
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

part of '../base_model.dart';

/// An image model that connects to a remote server template.
@experimental
final class TemplateImagenModel extends BaseTemplateApiClientModel {
  /// Internal constructor, exposed for testing.
  ///
  /// Not part of the public API surface; use [createTemplateImagenModel].
  @internal
  TemplateImagenModel.internal(
      {required FirebaseApp app,
      required String location,
      required bool useVertexBackend,
      bool? useLimitedUseAppCheckTokens,
      FirebaseAppCheck? appCheck,
      FirebaseAuth? auth,
      http.Client? httpClient})
      : super(
          // Imagen responses use the Vertex wire format on both backends.
          serializationStrategy: VertexSerialization(),
          // Server templates are not bound to a concrete model, so the model
          // segment of the URI is intentionally empty.
          modelUri: useVertexBackend
              ? _VertexUri(app: app, model: '', location: location)
              : _GoogleAIUri(app: app, model: ''),
          client: HttpApiClient(
              apiKey: app.options.apiKey,
              httpClient: httpClient,
              requestHeaders: BaseModel.firebaseTokens(
                  appCheck, auth, app, useLimitedUseAppCheckTokens)),
          templateUri: useVertexBackend
              ? _TemplateVertexUri(app: app, location: location)
              : _TemplateGoogleAIUri(app: app),
        );

  /// Private constructor.
  ///
  /// NOTE(review): nearly identical to [TemplateImagenModel.internal], except
  /// it omits `httpClient` from the [HttpApiClient] call; kept separate rather
  /// than redirecting in case that parameter's default differs from null —
  /// TODO confirm and consolidate.
  TemplateImagenModel._(
      {required FirebaseApp app,
      required String location,
      required bool useVertexBackend,
      bool? useLimitedUseAppCheckTokens,
      FirebaseAppCheck? appCheck,
      FirebaseAuth? auth})
      : super(
          serializationStrategy: VertexSerialization(),
          modelUri: useVertexBackend
              ? _VertexUri(app: app, model: '', location: location)
              : _GoogleAIUri(app: app, model: ''),
          client: HttpApiClient(
              apiKey: app.options.apiKey,
              requestHeaders: BaseModel.firebaseTokens(
                  appCheck, auth, app, useLimitedUseAppCheckTokens)),
          templateUri: useVertexBackend
              ? _TemplateVertexUri(app: app, location: location)
              : _TemplateGoogleAIUri(app: app),
        );

  /// Generates images from a template with the given [templateId] and
  /// [inputs].
  @experimental
  Future<ImagenGenerationResponse<ImagenInlineImage>> generateImages(
          String templateId,
          {Map<String, Object?>? inputs}) =>
      makeTemplateRequest(
        TemplateTask.templatePredict,
        templateId,
        inputs,
        null,
        (jsonObject) =>
            parseImagenGenerationResponse<ImagenInlineImage>(jsonObject),
      );
}

/// Returns a [TemplateImagenModel] using its private constructor.
@experimental
TemplateImagenModel createTemplateImagenModel({
  required FirebaseApp app,
  required String location,
  required bool useVertexBackend,
  bool? useLimitedUseAppCheckTokens,
  FirebaseAppCheck? appCheck,
  FirebaseAuth? auth,
}) =>
    TemplateImagenModel._(
      app: app,
      appCheck: appCheck,
      auth: auth,
      location: location,
      useVertexBackend: useVertexBackend,
      useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens,
    );
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import '../content.dart';

/// Aggregates a list of [Content] responses into a single [Content].
///
/// Includes all the [Content.parts] of every element of [contents],
/// and concatenates adjacent [TextPart]s into a single [TextPart],
/// even across adjacent [Content]s.
Content historyAggregate(List<Content> contents) {
  assert(contents.isNotEmpty);
  // The aggregate carries the role of the first content; falls back to
  // 'model' when the role is absent.
  final role = contents.first.role ?? 'model';
  final textBuffer = StringBuffer();
  // If non-null, only a single text part has been seen, so it can be reused
  // as-is instead of allocating a new TextPart from the buffer.
  TextPart? previousText;
  final parts = <Part>[];
  // Flushes any buffered text into [parts] as one TextPart.
  void addBufferedText() {
    if (textBuffer.isEmpty) return;
    if (previousText case final singleText?) {
      parts.add(singleText);
      previousText = null;
    } else {
      parts.add(TextPart(textBuffer.toString()));
    }
    textBuffer.clear();
  }

  for (final content in contents) {
    for (final part in content.parts) {
      if (part case TextPart(:final text)) {
        if (text.isNotEmpty) {
          // Track the part only while exactly one text part has been seen.
          previousText = textBuffer.isEmpty ? part : null;
          textBuffer.write(text);
        }
      } else {
        // A non-text part ends the current text run.
        addBufferedText();
        parts.add(part);
      }
    }
  }
  addBufferedText();
  return Content(role, parts);
}
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import 'dart:convert';

import 'package:firebase_ai/firebase_ai.dart';
import 'package:firebase_core/firebase_core.dart';
import 'package:flutter_test/flutter_test.dart';
import 'package:http/http.dart' as http;
import 'package:http/testing.dart';

import 'mock.dart';

// A response for generateContent and generateContentStream.
final _arbitraryGenerateContentResponse = {
  'candidates': [
    {
      'content': {
        'role': 'model',
        'parts': [
          {'text': 'Some response'},
        ],
      },
    },
  ],
};

// A response for Imagen's generateImages.
final _arbitraryImagenResponse = {
  'predictions': [
    {
      'mimeType': 'image/png',
      'bytesBase64Encoded':
          'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII='
    }
  ]
};

void main() {
  setupFirebaseVertexAIMocks();
  late FirebaseApp app;
  setUpAll(() async {
    app = await Firebase.initializeApp();
  });

  group('TemplateGenerativeModel', () {
    const templateId = 'my-template';
    const location = 'us-central1';

    TemplateGenerativeModel createModel(http.Client client,
        {bool useVertexBackend = true}) {
      // ignore: invalid_use_of_internal_member
      return TemplateGenerativeModel.internal(
        app: app,
        location: location,
        useVertexBackend: useVertexBackend,
        httpClient: client,
      );
    }

    test('generateContent can make successful request', () async {
      final mockHttp = MockClient((request) async {
        final body = jsonDecode(request.body) as Map;
        expect(request.url.path,
            endsWith('/templates/$templateId:templateGenerateContent'));
        expect(body['inputs'], {'prompt': 'Some prompt'});
        return http.Response(jsonEncode(_arbitraryGenerateContentResponse), 200,
            headers: {'content-type': 'application/json'});
      });

      final model = createModel(mockHttp);
      final response = await model
          .generateContent(templateId, inputs: {'prompt': 'Some prompt'});
      expect(response.text, 'Some response');
    });

    test('generateContentStream can make successful request', () async {
      final mockHttp = MockClient((request) async {
        final body = jsonDecode(request.body) as Map;
        expect(request.url.path,
            endsWith('/templates/$templateId:templateStreamGenerateContent'));
        expect(body['inputs'], {'prompt': 'Some prompt'});
        final responsePayload = jsonEncode(_arbitraryGenerateContentResponse);
        // Server-sent-events style payload: a single `data:` chunk.
        final stream = Stream.value(utf8.encode('data: $responsePayload'));
        final streamedResponse = http.StreamedResponse(stream, 200,
            headers: {'content-type': 'application/json'});
        return http.Response.fromStream(streamedResponse);
      });

      final model = createModel(mockHttp);
      final responseStream = model
          .generateContentStream(templateId, inputs: {'prompt': 'Some prompt'});
      final response = await responseStream.first;
      expect(response.text, 'Some response');
    });

    test('templateGenerateContentWithHistory includes history', () async {
      final history = [
        Content.text('Hi!'),
        Content.model([const TextPart('Hello there.')]),
      ];
      final mockHttp = MockClient((request) async {
        final body = jsonDecode(request.body) as Map;
        final contents = body['history'] as List;
        expect(contents, hasLength(2));
        expect(contents[0]['parts'][0]['text'], 'Hi!');
        expect(contents[1]['role'], 'model');
        return http.Response(jsonEncode(_arbitraryGenerateContentResponse), 200,
            headers: {'content-type': 'application/json'});
      });
      final model = createModel(mockHttp);
      await model.templateGenerateContentWithHistory(history, templateId);
    });

    test('templateGenerateContentWithHistoryStream includes history',
        () async {
      final history = [
        Content.text('Hi!'),
        Content.model([const TextPart('Hello there.')]),
      ];
      final mockHttp = MockClient((request) async {
        final body = jsonDecode(request.body) as Map;
        final contents = body['history'] as List;
        expect(contents, hasLength(2));
        expect(contents[0]['parts'][0]['text'], 'Hi!');
        expect(contents[1]['role'], 'model');
        final responsePayload = jsonEncode(_arbitraryGenerateContentResponse);
        final stream = Stream.value(utf8.encode('data: $responsePayload'));
        final streamedResponse = http.StreamedResponse(stream, 200,
            headers: {'content-type': 'application/json'});
        return http.Response.fromStream(streamedResponse);
      });
      final model = createModel(mockHttp);
      final responseStream =
          model.templateGenerateContentWithHistoryStream(history, templateId);
      await responseStream.drain();
    });
  });

  group('TemplateImagenModel', () {
    const templateId = 'my-imagen-template';
    const location = 'us-central1';

    TemplateImagenModel createModel(http.Client client,
        {bool useVertexBackend = true}) {
      // ignore: invalid_use_of_internal_member
      return TemplateImagenModel.internal(
        app: app,
        location: location,
        useVertexBackend: useVertexBackend,
        httpClient: client,
      );
    }

    test('generateImages can make successful request', () async {
      final mockHttp = MockClient((request) async {
        final body = jsonDecode(request.body) as Map;
        expect(request.url.path, endsWith('/templates/$templateId:predict'));
        expect(body['inputs'], {'prompt': 'A cat'});
        return http.Response(jsonEncode(_arbitraryImagenResponse), 200,
            headers: {'content-type': 'application/json'});
      });
      final model = createModel(mockHttp);
      final response =
          await model.generateImages(templateId, inputs: {'prompt': 'A cat'});
      expect(response.images, hasLength(1));
      expect(response.images.first, isA<ImagenInlineImage>());
    });
  });

  group('TemplateChatSession', () {
    const templateId = 'my-chat-template';
    late TemplateGenerativeModel model;

    test('sendMessage adds to history', () async {
      final mockHttp = MockClient((request) async {
        return http.Response(jsonEncode(_arbitraryGenerateContentResponse), 200,
            headers: {'content-type': 'application/json'});
      });
      // ignore: invalid_use_of_internal_member
      model = TemplateGenerativeModel.internal(
        app: app,
        location: 'us-central1',
        useVertexBackend: true,
        httpClient: mockHttp,
      );

      final chat = model.startChat(templateId);
      expect(chat.history, isEmpty);
      final response = await chat.sendMessage(Content.text('Hi'));
      // One user turn plus one model turn.
      expect(chat.history, hasLength(2));
      expect(chat.history.first.parts.first, isA<TextPart>());
      expect((chat.history.first.parts.first as TextPart).text, 'Hi');
      expect(chat.history.last.role, 'model');
      expect(chat.history.last.parts.first, isA<TextPart>());
      expect((chat.history.last.parts.first as TextPart).text, response.text);
    });

    test('sendMessageStream adds to history', () async {
      final mockHttp = MockClient((request) async {
        // Two SSE chunks whose text parts should be aggregated into one
        // history entry ('Some ' + 'response').
        final stream = Stream.fromIterable([
          'data: ${jsonEncode({
                'candidates': [
                  {
                    'content': {
                      'role': 'model',
                      'parts': [
                        {'text': 'Some '},
                      ],
                    },
                  },
                ],
              })}',
          'data: ${jsonEncode({
                'candidates': [
                  {
                    'content': {
                      'role': 'model',
                      'parts': [
                        {'text': 'response'},
                      ],
                    },
                  },
                ],
              })}'
        ].map(utf8.encode));
        final streamedResponse = http.StreamedResponse(stream, 200,
            headers: {'content-type': 'application/json'});
        return http.Response.fromStream(streamedResponse);
      });
      // ignore: invalid_use_of_internal_member
      model = TemplateGenerativeModel.internal(
        app: app,
        location: 'us-central1',
        useVertexBackend: true,
        httpClient: mockHttp,
      );
      final chat = model.startChat(templateId);
      expect(chat.history, isEmpty);
      final responseStream = chat.sendMessageStream(Content.text('Hi'));
      final responses = await responseStream.toList();
      expect(responses, hasLength(2));
      expect(chat.history, hasLength(2));
      expect(chat.history.first.parts.first, isA<TextPart>());
      expect((chat.history.first.parts.first as TextPart).text, 'Hi');
      expect(chat.history.last.role, 'model');
      expect(chat.history.last.parts.first, isA<TextPart>());
      expect((chat.history.last.parts.first as TextPart).text, 'Some response');
    });

    test('sendMessage with initial history', () async {
      final mockHttp = MockClient((request) async {
        return http.Response(jsonEncode(_arbitraryGenerateContentResponse), 200,
            headers: {'content-type': 'application/json'});
      });
      // ignore: invalid_use_of_internal_member
      model = TemplateGenerativeModel.internal(
        app: app,
        location: 'us-central1',
        useVertexBackend: true,
        httpClient: mockHttp,
      );
      final history = [
        Content.text('Hi!'),
        Content.model([const TextPart('Hello there.')]),
      ];
      final chat = model.startChat(templateId, history: history);
      expect(chat.history, hasLength(2));
      await chat.sendMessage(Content.text('How are you?'));
      // The two seeded turns plus the new user/model pair.
      expect(chat.history, hasLength(4));
    });
  });
}