Skip to content

Commit

Permalink
code snippet provider
Browse files Browse the repository at this point in the history
  • Loading branch information
lukka committed Dec 7, 2024
1 parent ee71e1f commit 22235f3
Show file tree
Hide file tree
Showing 4 changed files with 289 additions and 2 deletions.
26 changes: 24 additions & 2 deletions Extension/src/LanguageServer/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@ import {
} from './codeAnalysis';
import { Location, TextEdit, WorkspaceEdit } from './commonTypes';
import * as configs from './configurations';
import { CopilotCompletionContextProvider } from './copilotCompletionContextProvider';
import { DataBinding } from './dataBinding';
import { cachedEditorConfigSettings, getEditorConfigSettings } from './editorConfig';
import { CppSourceStr, clients, configPrefix, updateLanguageConfigurations, usesCrashHandler, watchForCrashes } from './extension';
import { CppSourceStr, SnippetEntry, clients, configPrefix, updateLanguageConfigurations, usesCrashHandler, watchForCrashes } from './extension';
import { LocalizeStringParams, getLocaleId, getLocalizedString } from './localization';
import { PersistentFolderState, PersistentWorkspaceState } from './persistentState';
import { RequestCancelled, ServerCancelled, createProtocolFilter } from './protocolFilter';
Expand Down Expand Up @@ -554,6 +555,15 @@ export interface ProjectContextResult {
fileContext: FileContextResult;
}

export interface CompletionContextsResult {
context: SnippetEntry[];
}

export interface CompletionContextParams {
file: string;
caretOffset: number;
}

// Requests
const PreInitializationRequest: RequestType<void, string, void> = new RequestType<void, string, void>('cpptools/preinitialize');
const InitializationRequest: RequestType<CppInitializationParams, void, void> = new RequestType<CppInitializationParams, void, void>('cpptools/initialize');
Expand All @@ -575,6 +585,7 @@ const ChangeCppPropertiesRequest: RequestType<CppPropertiesParams, void, void> =
const IncludesRequest: RequestType<GetIncludesParams, GetIncludesResult, void> = new RequestType<GetIncludesParams, GetIncludesResult, void>('cpptools/getIncludes');
const CppContextRequest: RequestType<TextDocumentIdentifier, ChatContextResult, void> = new RequestType<TextDocumentIdentifier, ChatContextResult, void>('cpptools/getChatContext');
const ProjectContextRequest: RequestType<TextDocumentIdentifier, ProjectContextResult, void> = new RequestType<TextDocumentIdentifier, ProjectContextResult, void>('cpptools/getProjectContext');
const CompletionContextRequest: RequestType<CompletionContextParams, CompletionContextsResult, void> = new RequestType<CompletionContextParams, CompletionContextsResult, void>('cpptools/getCompletionContext');

// Notifications to the server
const DidOpenNotification: NotificationType<DidOpenTextDocumentParams> = new NotificationType<DidOpenTextDocumentParams>('textDocument/didOpen');
Expand Down Expand Up @@ -807,6 +818,7 @@ export interface Client {
getIncludes(maxDepth: number): Promise<GetIncludesResult>;
getChatContext(uri: vscode.Uri, token: vscode.CancellationToken): Promise<ChatContextResult>;
getProjectContext(uri: vscode.Uri): Promise<ProjectContextResult>;
getCompletionContext(fileName: vscode.Uri, caretOffset: number, token: vscode.CancellationToken): Promise<CompletionContextsResult>;
}

export function createClient(workspaceFolder?: vscode.WorkspaceFolder): Client {
Expand Down Expand Up @@ -839,7 +851,7 @@ export class DefaultClient implements Client {
private settingsTracker: SettingsTracker;
private loggingLevel: number = 1;
private configurationProvider?: string;

private copilotCompletionProvider?: CopilotCompletionContextProvider;
public lastCustomBrowseConfiguration: PersistentFolderState<WorkspaceBrowseConfiguration | undefined> | undefined;
public lastCustomBrowseConfigurationProviderId: PersistentFolderState<string | undefined> | undefined;
public lastCustomBrowseConfigurationProviderVersion: PersistentFolderState<Version> | undefined;
Expand Down Expand Up @@ -1298,6 +1310,8 @@ export class DefaultClient implements Client {
this.semanticTokensProviderDisposable = vscode.languages.registerDocumentSemanticTokensProvider(util.documentSelector, this.semanticTokensProvider, semanticTokensLegend);
}

this.copilotCompletionProvider = await CopilotCompletionContextProvider.Create();

// Listen for messages from the language server.
this.registerNotifications();

Expand Down Expand Up @@ -1807,6 +1821,7 @@ export class DefaultClient implements Client {
if (diagnosticsCollectionIntelliSense) {
diagnosticsCollectionIntelliSense.delete(document.uri);
}
this.copilotCompletionProvider?.removeFile(uri);
openFileVersions.delete(uri);
}

Expand Down Expand Up @@ -2255,6 +2270,12 @@ export class DefaultClient implements Client {
() => this.languageClient.sendRequest(CppContextRequest, params, token), token);
}

public async getCompletionContext(file: vscode.Uri, caretOffset: number, token: vscode.CancellationToken): Promise<CompletionContextsResult> {
await withCancellation(this.ready, token);
return DefaultClient.withLspCancellationHandling(
() => this.languageClient.sendRequest(CompletionContextRequest, { file: file.toString(), caretOffset }, token), token);
}

/**
* a Promise that can be awaited to know when it's ok to proceed.
*
Expand Down Expand Up @@ -4159,4 +4180,5 @@ class NullClient implements Client {
getIncludes(maxDepth: number): Promise<GetIncludesResult> { return Promise.resolve({} as GetIncludesResult); }
getChatContext(uri: vscode.Uri, token: vscode.CancellationToken): Promise<ChatContextResult> { return Promise.resolve({} as ChatContextResult); }
getProjectContext(uri: vscode.Uri): Promise<ProjectContextResult> { return Promise.resolve({} as ProjectContextResult); }
getCompletionContext(file: vscode.Uri, caretOffset: number, token: vscode.CancellationToken): Promise<CompletionContextsResult> { return Promise.resolve({} as CompletionContextsResult); }
}
185 changes: 185 additions & 0 deletions Extension/src/LanguageServer/copilotCompletionContextProvider.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
/* --------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All Rights Reserved.
* See 'LICENSE' in the project root for license information.
* ------------------------------------------------------------------------------------------ */
import * as vscode from 'vscode';
import { DocumentSelector } from 'vscode-languageserver-protocol';
import { getOutputChannelLogger, Logger } from '../logger';
import * as telemetry from '../telemetry';
import { CopilotContextTelemetry } from './copilotContextTelemetry';
import { getCopilotApi } from './copilotProviders';
import { clients } from './extension';
import { CodeSnippet, CompletionContext, ContextProviderApiV1, ContextResolver } from './tmp/contextProviderV1';

class DefaultValueFallback extends Error {
static readonly DefaultValue = "DefaultValue";
constructor() { super(DefaultValueFallback.DefaultValue); }
}

class CancellationError extends Error {
static readonly Cancelled = "Cancelled";
constructor() { super(CancellationError.Cancelled); }
}

// Mutually exclusive values for the kind of snippets. They either are:
// - computed.
// - obtained from the cache.
// - missing and the computation is taking too long and no cache is present (cache miss). The value
// is asynchronously computed and stored in cache.
// - the token is signaled as cancelled, in which case all the operations are aborted.
// - an unknown state.
enum SnippetsKind {
Computed = 'computed',
GotFromCache = 'gotFromCacheHit',
MissingCacheMiss = 'missingCacheMiss',
Cancelled = 'cancelled',
Unknown = 'unknown'
}

export class CopilotCompletionContextProvider implements ContextResolver<CodeSnippet> {
private static readonly providerId = 'cppTools';
private readonly completionContextCache: Map<string, CodeSnippet[]> = new Map<string, CodeSnippet[]>();
private static readonly defaultCppDocumentSelector: DocumentSelector = [{ language: 'cpp' }, { language: 'c' }, { language: 'cuda-cpp' }];
private static readonly defaultTimeBudgetFactor: number = 0.5;
private completionContextCancellation = new vscode.CancellationTokenSource();

// Get the default value if the timeout expires, but throws an exception if the token is cancelled.
private async waitForCompletionWithTimeoutAndCancellation<T>(promise: Promise<T>, defaultValue: T | undefined,
timeout: number, token: vscode.CancellationToken): Promise<[T | undefined, SnippetsKind]> {
const defaultValuePromise = new Promise<T>((resolve, reject) => setTimeout(() => {
if (token.isCancellationRequested) {
reject(new CancellationError());
} else {
reject(new DefaultValueFallback());
}
}, timeout));
const cancellationPromise = new Promise<T>((_, reject) => {
token.onCancellationRequested(() => {
reject(new CancellationError());
});
});
let snippetsOrNothing: T | undefined;
try {
snippetsOrNothing = await Promise.race([promise, cancellationPromise, defaultValuePromise]);
} catch (e) {
if (e instanceof DefaultValueFallback) {
return [defaultValue, defaultValue !== undefined ? SnippetsKind.GotFromCache : SnippetsKind.MissingCacheMiss];
} else if (e instanceof CancellationError) {
return [undefined, SnippetsKind.Cancelled];
} else {
throw e;
}
}

return [snippetsOrNothing, SnippetsKind.Computed];
}

// Get the completion context with a timeout and a cancellation token.
// The cancellationToken indicates that the value should not be returned nor cached.
private async getCompletionContextWithCancellation(documentUri: string, caretOffset: number,
startTime: number, out: Logger, telemetry: CopilotContextTelemetry, token: vscode.CancellationToken): Promise<CodeSnippet[]> {
try {
const docUri = vscode.Uri.parse(documentUri);
const snippets = await clients.getClientFor(docUri).getCompletionContext(docUri, caretOffset, token);

const codeSnippets = snippets.context.map((item) => {
if (token.isCancellationRequested) {
telemetry.addCancelledLate();
throw new CancellationError();
}
return {
importance: item.importance, uri: item.uri, value: item.text
};
});

this.completionContextCache.set(documentUri, codeSnippets);
const duration: number = performance.now() - startTime;
out.appendLine(`Copilot: getCompletionContextWithCancellation(): Cached in [ms]: ${duration}`);
telemetry.addSnippetCount(codeSnippets?.length);
telemetry.addCacheComputedElapsed(duration);

return codeSnippets;
} catch (e) {
const err = e as Error;
out.appendLine(`Copilot: getCompletionContextWithCancellation(): Error: '${err?.message}', stack '${err?.stack}`);
telemetry.addError();
return [];
}
}

private async fetchTimeBudgetFactor(context: CompletionContext): Promise<number> {
const budgetFactor = context.activeExperiments.get("CppToolsCopilotTimeBudget");
return (budgetFactor as number) !== undefined ? budgetFactor as number : CopilotCompletionContextProvider.defaultTimeBudgetFactor;
}

public static async Create() {
const copilotCompletionProvider = new CopilotCompletionContextProvider();
await copilotCompletionProvider.registerCopilotContextProvider();
return copilotCompletionProvider;
}

public removeFile(fileUri: string): void {
this.completionContextCache.delete(fileUri);
}

public async resolve(context: CompletionContext, copilotAborts: vscode.CancellationToken): Promise<CodeSnippet[]> {
const startTime = performance.now();
const out: Logger = getOutputChannelLogger();
const timeBudgetFactor = await this.fetchTimeBudgetFactor(context);
const telemetry = new CopilotContextTelemetry();
let codeSnippets: CodeSnippet[] | undefined;
let codeSnippetsKind: SnippetsKind = SnippetsKind.Unknown;
try {
this.completionContextCancellation.cancel();
this.completionContextCancellation = new vscode.CancellationTokenSource();
const docUri = context.documentContext.uri;
const cachedValue: CodeSnippet[] | undefined = this.completionContextCache.get(docUri.toString());
const snippetsPromise = this.getCompletionContextWithCancellation(docUri,
context.documentContext.offset, startTime, out, telemetry.fork(), this.completionContextCancellation.token);
[codeSnippets, codeSnippetsKind] = await this.waitForCompletionWithTimeoutAndCancellation(
snippetsPromise, cachedValue, context.timeBudget * timeBudgetFactor, copilotAborts);
if (codeSnippetsKind === SnippetsKind.Cancelled) {
const duration: number = performance.now() - startTime;
out.appendLine(`Copilot: getCompletionContext(): cancelled, elapsed time (ms) : ${duration}`);
telemetry.addCancelled();
telemetry.addCancellationElapsed(duration);
throw new CancellationError();
}
telemetry.addSnippetCount(codeSnippets?.length);
return codeSnippets ?? [];
} catch (e: any) {
telemetry.addError();
throw e;
} finally {
telemetry.addKind(codeSnippetsKind.toString());
const duration: number = performance.now() - startTime;
if (codeSnippets === undefined) {
out.appendLine(`Copilot: getCompletionContext(): no snkppets provided (${codeSnippetsKind.toString()}), elapsed time (ms): ${duration}`);
} else {
out.appendLine(`Copilot: getCompletionContext(): provided ${codeSnippets?.length} snippets (${codeSnippetsKind.toString()}), elapsed time (ms): ${duration}`);
}
telemetry.addResolvedElapsed(duration);
telemetry.addCacheSize(this.completionContextCache.size);
// //?? TODO telemetry.file();
}

return [];
}

public async registerCopilotContextProvider(): Promise<void> {
try {
const isCustomSnippetProviderApiEnabled = await telemetry.isExperimentEnabled("CppToolsCustomSnippetsApi");
if (isCustomSnippetProviderApiEnabled) {
const contextAPI = (await getCopilotApi() as any).getContextProviderAPI('v1') as ContextProviderApiV1;
contextAPI.registerContextProvider({
id: CopilotCompletionContextProvider.providerId,
selector: CopilotCompletionContextProvider.defaultCppDocumentSelector,
resolver: this
});
}
} catch {
console.warn("Failed to register the Copilot Context Provider.");
telemetry.logCopilotEvent("registerCopilotContextProviderError", { "message": "Failed to register the Copilot Context Provider." });
}
}
}
72 changes: 72 additions & 0 deletions Extension/src/LanguageServer/copilotContextTelemetry.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/* --------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All Rights Reserved.
* See 'LICENSE' in the project root for license information.
* ------------------------------------------------------------------------------------------ */
import { randomUUID } from 'crypto';
import * as telemetry from '../telemetry';

export class CopilotContextTelemetry {
private static readonly correlationIdKey = 'correlationId';
private static readonly copilotEventName = 'copilotContextProvider';
private readonly metrics: Record<string, number> = {};
private readonly properties: Record<string, string> = {};
private readonly id: string;
constructor(correlationId?: string) {
this.id = correlationId ?? randomUUID().toString();
}

private addMetric(key: string, value: number): void {
this.metrics[key] = value;
}

private addProperty(key: string, value: string): void {
this.properties[key] = value;
}

public addCancelled(): void {
this.addProperty('cancelled', 'true');
}

public addCancellationElapsed(duration: number): void {
this.addMetric('cancellationElapsedMs', duration);
}

public addCancelledLate(): void {
this.addProperty('cancelledLate', 'true');
}

public addError(): void {
this.addProperty('error', 'true');
}

public addKind(snippetsKind: string): void {
this.addProperty('kind', snippetsKind.toString());
}

public addResolvedElapsed(duration: number): void {
this.addMetric('overallResolveElapsedMs', duration);
}

public addCacheSize(size: number): void {
this.addMetric('cacheSize', size);
}

public addCacheComputedElapsed(duration: number): void {
this.addMetric('cacheComputedElapsedMs', duration);
}

// count can be undefined, in which case the count is set to -1 to indicate
// snippets are not available (different than having 0 snippets).
public addSnippetCount(count?: number) {
this.addMetric('snippetsCount', count ?? -1);
}

public file(): void {
this.properties[CopilotContextTelemetry.correlationIdKey] = this.id;
telemetry.logCopilotEvent(CopilotContextTelemetry.copilotEventName, this.properties, this.metrics);
}

public fork(): CopilotContextTelemetry {
return new CopilotContextTelemetry(this.id);
}
}
8 changes: 8 additions & 0 deletions Extension/src/LanguageServer/extension.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ import { CppSettings } from './settings';
import { LanguageStatusUI, getUI } from './ui';
import { makeLspRange, rangeEquals, showInstallCompilerWalkthrough } from './utils';

export interface SnippetEntry {
uri: string;
text: string;
startLine: number;
endLine: number;
importance: number;
}

nls.config({ messageFormat: nls.MessageFormat.bundle, bundleFormat: nls.BundleFormat.standalone })();
const localize: nls.LocalizeFunc = nls.loadMessageBundle();
export const CppSourceStr: string = "C/C++";
Expand Down

0 comments on commit 22235f3

Please sign in to comment.