
Commit

Merge pull request #24 from sgomez/support-custom-provider
Support custom provider
sgomez authored Aug 27, 2024
2 parents 1918f79 + b298722 commit 1e47a34
Showing 8 changed files with 372 additions and 443 deletions.
examples/ai-core/Makefile (7 changes: 5 additions & 2 deletions)
@@ -41,15 +41,18 @@ embed-many_ollama-cosine-similarity:


# registry
-.PHONY: registry registry-run registry-all registry_embed registry_stream-text
+.PHONY: registry registry-run registry-all registry_embed registry_stream-text registry_stream-multimodal
registry: registry-run registry-all
registry-run:
	echo - examples/registry:
-registry-all: registry_embed registry_stream-text
+registry-all: registry_embed registry_stream-text registry_stream-multimodal
registry_embed:
	$(call RUN_EXAMPLE_TARGET,$@)
registry_stream-text:
	$(call RUN_EXAMPLE_TARGET,$@)
+registry_stream-multimodal:
+	$(call RUN_EXAMPLE_TARGET,$@)



# generate-object
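
With these targets in place, running make registry_stream-multimodal (or make registry-all) from the examples/ai-core directory should exercise the new example through the same RUN_EXAMPLE_TARGET helper as the existing registry targets; that helper is defined elsewhere in this Makefile and is not shown in the diff.
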
examples/ai-core/src/registry/setup-registry.ts (16 changes: 13 additions & 3 deletions)
@@ -1,7 +1,17 @@
-import { experimental_createProviderRegistry as createProviderRegistry } from 'ai'
+import {
+  experimental_createProviderRegistry as createProviderRegistry,
+  experimental_customProvider as customProvider,
+} from 'ai'
import { ollama } from 'ollama-ai-provider'

+const myOllama = customProvider({
+  fallbackProvider: ollama,
+  languageModels: {
+    multimodal: ollama('llava'),
+    text: ollama('llama3.1'),
+  },
+})

export const registry = createProviderRegistry({
  // register provider with prefix and custom setup:
-  ollama,
+  ollama: myOllama,
})
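
For orientation (not part of the commit): a minimal TypeScript sketch of how the registered aliases resolve at a call site, assuming the registry's default 'providerId:modelId' separator. The prompt is illustrative; any model id not listed in languageModels is delegated to the plain ollama fallbackProvider.

import { streamText } from 'ai'

import { registry } from './setup-registry'

async function demo() {
  // 'ollama:text' -> ollama('llama3.1'); 'ollama:multimodal' -> ollama('llava').
  // An unlisted id such as 'ollama:phi3' falls through to the fallbackProvider.
  const result = await streamText({
    model: registry.languageModel('ollama:text'),
    prompt: 'Why is the sky blue?',
  })

  for await (const part of result.textStream) {
    process.stdout.write(part)
  }
}

demo().catch(console.error)
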
examples/ai-core/src/registry/stream-multimodal.ts (31 changes: 31 additions & 0 deletions)
@@ -0,0 +1,31 @@
#! /usr/bin/env -S pnpm tsx

import fs from 'node:fs'

import { streamText } from 'ai'
import { ollama } from 'ollama-ai-provider'

import { buildProgram } from '../tools/command'
import { registry } from './setup-registry'

async function main(model: Parameters<typeof ollama>[0]) {
  const result = await streamText({
    maxTokens: 512,
    messages: [
      {
        content: [
          { text: 'Describe the image in detail.', type: 'text' },
          { image: fs.readFileSync('./data/comic-cat.png'), type: 'image' },
        ],
        role: 'user',
      },
    ],
    model: registry.languageModel(model),
  })

  for await (const textPart of result.textStream) {
    process.stdout.write(textPart)
  }
}

buildProgram('ollama:multimodal', main).catch(console.error)
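
The shebang suggests the script is run directly with pnpm tsx (for example, pnpm tsx src/registry/stream-multimodal.ts from examples/ai-core, or via the new registry_stream-multimodal Makefile target). buildProgram comes from ../tools/command, which is not part of this diff; judging by its use here and in stream-text.ts, it presumably parses an optional model-id argument and falls back to the default given as its first parameter, 'ollama:multimodal'.
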
examples/ai-core/src/registry/stream-text.ts (2 changes: 1 addition & 1 deletion)
@@ -17,4 +17,4 @@ async function main(model: Parameters<typeof ollama>[0]) {
  }
}

-buildProgram('ollama:llama3', main).catch(console.error)
+buildProgram('ollama:text', main).catch(console.error)
packages/ollama/package.json (6 changes: 3 additions & 3 deletions)
@@ -26,13 +26,13 @@
"author": "Sergio Gómez Bachiller <[email protected]>",
"license": "Apache-2.0",
"dependencies": {
"@ai-sdk/provider": "0.0.21",
"@ai-sdk/provider-utils": "1.0.16",
"@ai-sdk/provider": "0.0.22",
"@ai-sdk/provider-utils": "1.0.17",
"partial-json": "0.1.7"
},
"devDependencies": {
"@edge-runtime/vm": "^3.2.0",
"@types/node": "^18.19.43",
"@types/node": "^18.19.46",
"tsup": "^8.2.4",
"typescript": "5.5.4",
"zod": "3.23.8"
packages/ollama/src/index.ts (3 changes: 2 additions & 1 deletion)
@@ -1,2 +1,3 @@
export * from './ollama-facade'
-export * from './ollama-provider'
+export type { OllamaProvider, OllamaProviderSettings } from './ollama-provider'
+export { createOllama, ollama } from './ollama-provider'
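
A hypothetical consumer-side sketch (not from this commit) of the now-explicit exports in use; the baseURL option is an assumption about OllamaProviderSettings, whose fields are not shown in this diff.

import { createOllama, ollama, type OllamaProvider } from 'ollama-ai-provider'

// Hypothetical: a second provider instance pointed at a non-default Ollama server.
const remoteOllama: OllamaProvider = createOllama({
  baseURL: 'http://gpu-box.local:11434/api',
})

const defaultModel = ollama('llama3.1')
const remoteModel = remoteOllama('llama3.1')
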
packages/ollama/src/ollama-provider.ts (28 changes: 19 additions & 9 deletions)
@@ -1,3 +1,4 @@
+import { EmbeddingModelV1, LanguageModelV1, ProviderV1 } from '@ai-sdk/provider'
import { withoutTrailingSlash } from '@ai-sdk/provider-utils'

import { OllamaChatLanguageModel } from '@/ollama-chat-language-model'
@@ -8,31 +9,39 @@ import {
  OllamaEmbeddingSettings,
} from '@/ollama-embedding-settings'

-export interface OllamaProvider {
-  (
-    modelId: OllamaChatModelId,
-    settings?: OllamaChatSettings,
-  ): OllamaChatLanguageModel
+export interface OllamaProvider extends ProviderV1 {
+  (modelId: OllamaChatModelId, settings?: OllamaChatSettings): LanguageModelV1

  chat(
    modelId: OllamaChatModelId,
    settings?: OllamaChatSettings,
-  ): OllamaChatLanguageModel
+  ): LanguageModelV1

+  /**
+   * @deprecated Use `textEmbeddingModel` instead.
+   */
  embedding(
    modelId: OllamaEmbeddingModelId,
    settings?: OllamaEmbeddingSettings,
-  ): OllamaEmbeddingModel
+  ): EmbeddingModelV1<string>

  languageModel(
    modelId: OllamaChatModelId,
    settings?: OllamaChatSettings,
-  ): OllamaChatLanguageModel
+  ): LanguageModelV1

+  /**
+   * @deprecated Use `textEmbeddingModel()` instead.
+   */
  textEmbedding(
    modelId: OllamaEmbeddingModelId,
    settings?: OllamaEmbeddingSettings,
-  ): OllamaEmbeddingModel
+  ): EmbeddingModelV1<string>
+
+  textEmbeddingModel(
+    modelId: OllamaEmbeddingModelId,
+    settings?: OllamaEmbeddingSettings,
+  ): EmbeddingModelV1<string>
}

export interface OllamaProviderSettings {
@@ -104,6 +113,7 @@ export function createOllama(
  provider.embedding = createEmbeddingModel
  provider.languageModel = createChatModel
  provider.textEmbedding = createEmbeddingModel
+  provider.textEmbeddingModel = createEmbeddingModel

  return provider as OllamaProvider
}
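
A minimal sketch (not part of the commit) of the new ProviderV1-conformant surface from the consumer side; 'nomic-embed-text' is only an illustrative embedding model id.

import { embedMany } from 'ai'
import { ollama } from 'ollama-ai-provider'

async function demo() {
  // textEmbeddingModel() is the non-deprecated path; embedding() and
  // textEmbedding() still work but now carry @deprecated notes.
  const model = ollama.textEmbeddingModel('nomic-embed-text')

  const { embeddings } = await embedMany({
    model,
    values: ['sunny day at the beach', 'rainy afternoon in the city'],
  })

  console.log(embeddings.length)
}

demo().catch(console.error)
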