Skip to content

Commit

Permalink
Merge pull request #96 from UN-OCHA/develop
Browse files Browse the repository at this point in the history
v1.7.1
  • Loading branch information
orakili authored Nov 13, 2024
2 parents 429d086 + 43635dd commit 2ff9e6a
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 23 deletions.
8 changes: 8 additions & 0 deletions modules/ocha_ai_tag/src/Services/OchaAiTagTagger.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
use Drupal\Core\Cache\Cache;
use Drupal\Core\Cache\CacheBackendInterface;
use Drupal\Core\Config\ConfigFactoryInterface;
use Drupal\Core\Config\ImmutableConfig;
use Drupal\Core\Database\Connection;
use Drupal\Core\Logger\LoggerChannelFactoryInterface;
use Drupal\Core\Session\AccountProxyInterface;
Expand Down Expand Up @@ -37,6 +38,13 @@ class OchaAiTagTagger extends OchaAiChat {
*/
protected CacheBackendInterface $cacheBackend;

/**
* The AI tagger config.
*
* @var \Drupal\Core\Config\ImmutableConfig
*/
protected ImmutableConfig $config;

/**
* Vocabulary mapping.
*
Expand Down
18 changes: 18 additions & 0 deletions src/Plugin/CompletionPluginInterface.php
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,24 @@ interface CompletionPluginInterface {
*/
public function answer(string $question, string $context): string;

/**
* Perform a completion query.
*
* @param string $prompt
* Prompt.
* @param string $system_prompt
* Optional system prompt.
* @param array $parameters
* Optional parameters for the payload: max_tokens, temperature, top_p.
* @param bool $raw
* Whether to return the raw output text or let the plugin do some
* processing if any.
*
* @return string|null
* The model output text or NULL in case of error when querying the model.
*/
public function query(string $prompt, string $system_prompt = '', array $parameters = [], bool $raw = TRUE): ?string;

/**
* Get the prompt template.
*
Expand Down
40 changes: 29 additions & 11 deletions src/Plugin/ocha_ai/Completion/AwsBedrock.php
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,20 @@ public function answer(string $question, string $context): string {
return '';
}

return $this->query($prompt, raw: FALSE) ?? '';
}

/**
* {@inheritdoc}
*/
public function query(string $prompt, string $system_prompt = '', array $parameters = [], bool $raw = TRUE): ?string {
if (empty($prompt)) {
return '';
}

$payload = [
'accept' => 'application/json',
'body' => json_encode($this->generateRequestBody($prompt)),
'body' => json_encode($this->generateRequestBody($prompt, $parameters)),
'contentType' => 'application/json',
'modelId' => $this->getPluginSetting('model'),
];
Expand All @@ -67,20 +78,24 @@ public function answer(string $question, string $context): string {
return '';
}

return $this->parseResponseBody($data);
return $this->parseResponseBody($data, $raw);
}

/**
* Generate the request body for the completion.
*
* @param string $prompt
* Prompt.
* @param array $parameters
* Parameters for the payload: max_tokens, temperature, top_p.
*
* @return array
* Request body.
*/
protected function generateRequestBody(string $prompt): array {
$max_tokens = (int) $this->getPluginSetting('max_tokens', 512);
protected function generateRequestBody(string $prompt, array $parameters = []): array {
$max_tokens = (int) ($parameters['max_tokens'] ?? $this->getPluginSetting('max_tokens', 512));
$temperature = (float) ($parameters['temperature'] ?? 0.0);
$top_p = (float) ($parameters['top_p'] ?? 0.9);

switch ($this->getPluginSetting('model')) {
case 'amazon.titan-text-express-v1':
Expand All @@ -90,16 +105,16 @@ protected function generateRequestBody(string $prompt): array {
'maxTokenCount' => $max_tokens,
// @todo adjust based on the prompt?
'stopSequences' => [],
'temperature' => 0.0,
'topP' => 0.9,
'temperature' => $temperature,
'topP' => $top_p,
],
];

case 'anthropic.claude-instant-v1':
return [
'prompt' => "\n\nHuman:$prompt\n\nAssistant:",
'temperature' => 0.0,
'top_p' => 0.9,
'temperature' => $temperature,
'top_p' => $top_p,
'top_k' => 0,
'max_tokens_to_sample' => $max_tokens,
'stop_sequences' => ["\n\nHuman:"],
Expand All @@ -109,8 +124,8 @@ protected function generateRequestBody(string $prompt): array {
case 'cohere.command-light-text-v14':
return [
'prompt' => $prompt,
'temperature' => 0.0,
'p' => 0.9,
'temperature' => $temperature,
'p' => $top_p,
'k' => 0.0,
'max_tokens' => $max_tokens,
'stop_sequences' => [],
Expand All @@ -129,11 +144,14 @@ protected function generateRequestBody(string $prompt): array {
*
* @param array $data
* Decoded response.
* @param bool $raw
* Whether to return the raw output text or let the plugin do some
* processing if any.
*
* @return string
* The generated text.
*/
protected function parseResponseBody(array $data): string {
protected function parseResponseBody(array $data, bool $raw = TRUE): string {
switch ($this->getPluginSetting('model')) {
case 'amazon.titan-text-express-v1':
return trim($data['results'][0]['outputText'] ?? '');
Expand Down
16 changes: 11 additions & 5 deletions src/Plugin/ocha_ai/Completion/AwsBedrockTitanTextPremierV1.php
Original file line number Diff line number Diff line change
Expand Up @@ -94,30 +94,36 @@ public function getPromptTemplate(): string {
/**
* {@inheritdoc}
*/
protected function generateRequestBody(string $prompt): array {
$max_tokens = (int) $this->getPluginSetting('max_tokens', 512);
protected function generateRequestBody(string $prompt, array $parameters = []): array {
$max_tokens = (int) ($parameters['max_tokens'] ?? $this->getPluginSetting('max_tokens', 512));
$temperature = (float) ($parameters['temperature'] ?? 0.0);
$top_p = (float) ($parameters['top_p'] ?? 0.9);

return [
'inputText' => $prompt,
'textGenerationConfig' => [
'maxTokenCount' => $max_tokens,
// @todo adjust based on the prompt?
'stopSequences' => [],
'temperature' => 0.0,
'topP' => 0.9,
'temperature' => $temperature,
'topP' => $top_p,
],
];
}

/**
* {@inheritdoc}
*/
protected function parseResponseBody(array $data): string {
protected function parseResponseBody(array $data, bool $raw = TRUE): string {
$response = trim($data['results'][0]['outputText'] ?? '');
if ($response === '') {
return '';
}

if ($raw) {
return $response;
}

// Extract the answer.
$start = mb_strpos($response, '<answer>');
$end = mb_strpos($response, '</answer>');
Expand Down
29 changes: 22 additions & 7 deletions src/Plugin/ocha_ai/Completion/AzureOpenAi.php
Original file line number Diff line number Diff line change
Expand Up @@ -43,23 +43,38 @@ public function answer(string $question, string $context): string {
return '';
}

return $this->query($question, $prompt) ?? '';
}

/**
* {@inheritdoc}
*/
public function query(string $prompt, string $system_prompt = '', array $parameters = [], bool $process = FALSE): ?string {
if (empty($prompt)) {
return '';
}

$max_tokens = (int) ($parameters['max_tokens'] ?? $this->getPluginSetting('max_tokens', 512));
$temperature = (float) ($parameters['temperature'] ?? 0.0);
$top_p = (float) ($parameters['top_p'] ?? 0.9);

$messages = [
[
'role' => 'system',
'content' => $prompt,
'content' => $system_prompt ?: 'You are a helpful assistant.',
],
[
'role' => 'user',
'content' => $question,
'content' => $prompt,
],
];

$payload = [
'model' => $this->getPluginSetting('model'),
'messages' => $messages,
'temperature' => 0.0,
'top_p' => 0.9,
'max_tokens' => (int) $this->getPluginSetting('max_tokens', 512),
'temperature' => $temperature,
'top_p' => $top_p,
'max_tokens' => $max_tokens,
];

try {
Expand All @@ -70,7 +85,7 @@ public function answer(string $question, string $context): string {
$this->getLogger()->error(strtr('Completion request failed with: @error.', [
'@error' => $exception->getMessage(),
]));
return '';
return NULL;
}

try {
Expand All @@ -80,7 +95,7 @@ public function answer(string $question, string $context): string {
$this->getLogger()->error(strtr('Unable to retrieve completion result data: @error.', [
'@error' => $exception->getMessage(),
]));
return '';
return NULL;
}

return trim($data['choices'][0]['message']['content'] ?? '');
Expand Down

0 comments on commit 2ff9e6a

Please sign in to comment.