Skip to content

Commit

Permalink
Memory optimisation
Browse files Browse the repository at this point in the history
  • Loading branch information
OperKH committed Feb 18, 2024
1 parent bf7216c commit b275607
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 8 deletions.
7 changes: 6 additions & 1 deletion src/bot/bot.class.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { ConfigService } from '../config/config.service.js';

export class Bot {
private bot: Telegraf<IBotContext>;
private commands: Command[] = [];

constructor(private readonly configService: ConfigService) {
this.bot = new Telegraf<IBotContext>(this.configService.get('TOKEN'));
Expand All @@ -14,6 +15,7 @@ export class Bot {
registerCommands(commands: Array<{ new (bot: Telegraf<IBotContext>): Command }>) {
for (const Command of commands) {
const command = new Command(this.bot);
this.commands.push(command);
command.handle();
}
}
Expand All @@ -23,7 +25,10 @@ export class Bot {
console.log('Bot started');
}

stop(reason?: string) {
async stop(reason?: string) {
for (const command of this.commands) {
await command.dispose();
}
this.bot.stop(reason);
}
}
4 changes: 4 additions & 0 deletions src/bot/commands/classifyMessage.command.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,8 @@ export class ClassifyMessageCommand extends Command {
}
});
}

async dispose() {
await this.aiService.dispose();
}
}
2 changes: 2 additions & 0 deletions src/bot/commands/command.class.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ export abstract class Command {
constructor(public readonly bot: Telegraf<IBotContext>) {}

abstract handle(): void;

abstract dispose(): Promise<void>;
}
2 changes: 2 additions & 0 deletions src/bot/commands/start.command.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@ export class StartCommand extends Command {
ctx.reply("Привіт, я поки нічого не вмію, але обов'язково навчуся!");
});
}

async dispose() {}
}
45 changes: 38 additions & 7 deletions src/services/ai.service.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { env, pipeline } from '@xenova/transformers';
import { env, pipeline, TextClassificationPipeline, TranslationPipeline } from '@xenova/transformers';
env.cacheDir = './data/models';

export type TranslatorResponse = {
Expand Down Expand Up @@ -33,27 +33,58 @@ export class AIService {
return AIService.instance;
}

private sentimentAnalysisPipeline: Promise<TextClassificationPipeline> | null = null;
private toxicAnalysisPipeline: Promise<TextClassificationPipeline> | null = null;
private translationPipeline: Promise<TranslationPipeline> | null = null;

public async dispose() {
await Promise.all([
this.sentimentAnalysisPipeline?.then((c) => c.dispose()),
this.toxicAnalysisPipeline?.then((c) => c.dispose()),
this.translationPipeline?.then((c) => c.dispose()),
]);
}

private getSentimentAnalysisPipeline() {
if (!this.sentimentAnalysisPipeline) {
this.sentimentAnalysisPipeline = pipeline(
'sentiment-analysis',
'Xenova/distilbert-base-uncased-finetuned-sst-2-english',
);
}
return this.sentimentAnalysisPipeline;
}
private getToxicAnalysisPipeline() {
if (!this.toxicAnalysisPipeline) {
this.toxicAnalysisPipeline = pipeline('sentiment-analysis', 'Xenova/toxic-bert');
}
return this.toxicAnalysisPipeline;
}
private getTranslationPipeline() {
if (!this.translationPipeline) {
this.translationPipeline = pipeline('translation', 'Xenova/nllb-200-distilled-600M');
}
return this.translationPipeline;
}

public async sentimentAnalysis(text: string) {
const classifier = await pipeline('sentiment-analysis', 'Xenova/distilbert-base-uncased-finetuned-sst-2-english');
const classifier = await this.getSentimentAnalysisPipeline();
const output = await classifier(text);
await classifier.dispose();
console.log('sentimentAnalysis', text, output);
return output as DistilBertResponse[];
}

public async toxicAnalysis(text: string) {
const classifier = await pipeline('sentiment-analysis', 'Xenova/toxic-bert');
const classifier = await this.getToxicAnalysisPipeline();
const output = await classifier(text, { topk: 6 });
await classifier.dispose();
console.log('toxicAnalysis', text, output);
return output as ToxicBertResponse[];
}

public async translate(text: string) {
const translator = await pipeline('translation', 'Xenova/nllb-200-distilled-600M');
const translator = await this.getTranslationPipeline();
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const output = await translator(text, { src_lang: 'rus_Cyrl', tgt_lang: 'eng_Latn' } as any);
await translator.dispose();
console.log('translate', text, output);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const [{ translation_text }] = output as TranslatorResponse[];
Expand Down

0 comments on commit b275607

Please sign in to comment.