Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a timeout for getting suggestions from the LMProvider #18234

Open
wants to merge 7 commits into
base: feature/llm
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 36 additions & 18 deletions src/cascadia/QueryExtension/AzureLLMProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,31 +145,49 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
// Send the request
try
{
const auto response = _httpClient.SendRequestAsync(request).get();
// Parse out the suggestion from the response
const auto string{ response.Content().ReadAsStringAsync().get() };
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
if (jsonResult.HasKey(errorString))
{
const auto errorObject = jsonResult.GetNamedObject(errorString);
message = errorObject.GetNamedString(messageString);
errorType = ErrorTypes::FromProvider;
}
else
const auto sendRequestOperation = _httpClient.SendRequestAsync(request);

// if the caller cancels this operation, make sure to cancel the http request as well
auto cancellationToken{ co_await winrt::get_cancellation_token() };
cancellationToken.callback([sendRequestOperation] {
sendRequestOperation.Cancel();
});

if (sendRequestOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed)
{
if (_verifyModelIsValidHelper(jsonResult))
// Parse out the suggestion from the response
const auto response = sendRequestOperation.GetResults();
const auto string{ co_await response.Content().ReadAsStringAsync() };
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
if (jsonResult.HasKey(errorString))
{
const auto choices = jsonResult.GetNamedArray(L"choices");
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(messageString);
message = messageObject.GetNamedString(contentString);
const auto errorObject = jsonResult.GetNamedObject(errorString);
message = errorObject.GetNamedString(messageString);
errorType = ErrorTypes::FromProvider;
}
else
{
message = RS_(L"InvalidModelMessage");
errorType = ErrorTypes::InvalidModel;
if (_verifyModelIsValidHelper(jsonResult))
{
const auto choices = jsonResult.GetNamedArray(L"choices");
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(messageString);
message = messageObject.GetNamedString(contentString);
}
else
{
message = RS_(L"InvalidModelMessage");
errorType = ErrorTypes::InvalidModel;
}
}
}
else
{
// if the http request takes too long, cancel the http request and return an error
sendRequestOperation.Cancel();
message = RS_(L"UnknownErrorMessage");
errorType = ErrorTypes::Unknown;
}
}
catch (...)
{
Expand Down
11 changes: 10 additions & 1 deletion src/cascadia/QueryExtension/ExtensionPalette.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,16 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation

if (_lmProvider)
{
result = _lmProvider.GetResponseAsync(promptCopy).get();
const auto asyncOperation = _lmProvider.GetResponseAsync(promptCopy);
if (asyncOperation.wait_for(std::chrono::seconds(15)) == AsyncStatus::Completed)
{
result = asyncOperation.GetResults();
}
else
{
asyncOperation.Cancel();
result = winrt::make<SystemResponse>(RS_(L"UnknownErrorMessage"), ErrorTypes::Unknown, winrt::hstring{});
}
}
else
{
Expand Down
70 changes: 57 additions & 13 deletions src/cascadia/QueryExtension/GithubCopilotLLMProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
// Make sure we are on the background thread for the http request
auto strongThis = get_strong();
co_await winrt::resume_background();
auto cancellationToken{ co_await winrt::get_cancellation_token() };

for (bool refreshAttempted = false;;)
{
Expand Down Expand Up @@ -276,19 +277,37 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
};

// Send the request
const auto jsonResult = co_await _SendRequestReturningJson(_endpointUri, requestContent, WWH::HttpMethod::Post());
if (jsonResult.HasKey(errorKey))
const auto sendRequestOperation = _SendRequestReturningJson(_endpointUri, requestContent, WWH::HttpMethod::Post());

// if the caller cancels this operation, make sure to cancel the http request as well
cancellationToken.callback([sendRequestOperation] {
sendRequestOperation.Cancel();
});

if (sendRequestOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed)
{
const auto errorObject = jsonResult.GetNamedObject(errorKey);
message = errorObject.GetNamedString(messageKey);
errorType = ErrorTypes::FromProvider;
// Parse out the suggestion from the response
const auto jsonResult = sendRequestOperation.GetResults();
if (jsonResult.HasKey(errorKey))
{
const auto errorObject = jsonResult.GetNamedObject(errorKey);
message = errorObject.GetNamedString(messageKey);
errorType = ErrorTypes::FromProvider;
}
else
{
const auto choices = jsonResult.GetNamedArray(choicesKey);
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(messageKey);
message = messageObject.GetNamedString(contentKey);
}
}
else
{
const auto choices = jsonResult.GetNamedArray(choicesKey);
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(messageKey);
message = messageObject.GetNamedString(contentKey);
// if the http request takes too long, cancel the http request and return an error
sendRequestOperation.Cancel();
message = RS_(L"UnknownErrorMessage");
errorType = ErrorTypes::Unknown;
}
break;
}
Expand All @@ -305,8 +324,23 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
break;
}

co_await _refreshAuthTokens();
refreshAttempted = true;
const auto refreshTokensAction = _refreshAuthTokens();
cancellationToken.callback([refreshTokensAction] {
refreshTokensAction.Cancel();
});
// allow up to 10 seconds for reauthentication
if (refreshTokensAction.wait_for(std::chrono::seconds(10)) == AsyncStatus::Completed)
{
refreshAttempted = true;
}
else
{
// if the refresh action takes too long, cancel it and return an error
refreshTokensAction.Cancel();
message = RS_(L"UnknownErrorMessage");
errorType = ErrorTypes::Unknown;
break;
}
}

// Also make a new entry in our jsonMessages list, so the AI knows the full conversation so far
Expand Down Expand Up @@ -334,7 +368,12 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation

try
{
const auto jsonResult = co_await _SendRequestReturningJson(accessTokenEndpoint, requestContent, WWH::HttpMethod::Post());
const auto reAuthOperation = _SendRequestReturningJson(accessTokenEndpoint, requestContent, WWH::HttpMethod::Post());
auto cancellationToken{ co_await winrt::get_cancellation_token() };
cancellationToken.callback([reAuthOperation] {
reAuthOperation.Cancel();
});
const auto jsonResult{ co_await reAuthOperation };

_authToken = jsonResult.GetNamedString(accessTokenKey);
_refreshToken = jsonResult.GetNamedString(refreshTokenKey);
Expand All @@ -360,7 +399,12 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
WWH::HttpRequestMessage request{ method, Uri{ uri } };
request.Content(content);

const auto response{ co_await _httpClient.SendRequestAsync(request) };
const auto sendRequestOperation = _httpClient.SendRequestAsync(request);
auto cancellationToken{ co_await winrt::get_cancellation_token() };
cancellationToken.callback([sendRequestOperation] {
sendRequestOperation.Cancel();
});
const auto response{ co_await sendRequestOperation };
const auto string{ co_await response.Content().ReadAsStringAsync() };
_lastResponse = string;
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
Expand Down
42 changes: 30 additions & 12 deletions src/cascadia/QueryExtension/OpenAILLMProvider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,22 +100,40 @@ namespace winrt::Microsoft::Terminal::Query::Extension::implementation
// Send the request
try
{
const auto response = co_await _httpClient.SendRequestAsync(request);
// Parse out the suggestion from the response
const auto string{ co_await response.Content().ReadAsStringAsync() };
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
if (jsonResult.HasKey(L"error"))
const auto sendRequestOperation = _httpClient.SendRequestAsync(request);

// if the caller cancels this operation, make sure to cancel the http request as well
auto cancellationToken{ co_await winrt::get_cancellation_token() };
cancellationToken.callback([sendRequestOperation] {
sendRequestOperation.Cancel();
});

if (sendRequestOperation.wait_for(std::chrono::seconds(5)) == AsyncStatus::Completed)
{
const auto errorObject = jsonResult.GetNamedObject(L"error");
message = errorObject.GetNamedString(L"message");
errorType = ErrorTypes::FromProvider;
// Parse out the suggestion from the response
const auto response = sendRequestOperation.GetResults();
const auto string{ co_await response.Content().ReadAsStringAsync() };
const auto jsonResult{ WDJ::JsonObject::Parse(string) };
if (jsonResult.HasKey(L"error"))
{
const auto errorObject = jsonResult.GetNamedObject(L"error");
message = errorObject.GetNamedString(L"message");
errorType = ErrorTypes::FromProvider;
}
else
{
const auto choices = jsonResult.GetNamedArray(L"choices");
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(L"message");
message = messageObject.GetNamedString(L"content");
}
}
else
{
const auto choices = jsonResult.GetNamedArray(L"choices");
const auto firstChoice = choices.GetAt(0).GetObject();
const auto messageObject = firstChoice.GetNamedObject(L"message");
message = messageObject.GetNamedString(L"content");
// if the http request takes too long, cancel the http request and return an error
sendRequestOperation.Cancel();
message = RS_(L"UnknownErrorMessage");
errorType = ErrorTypes::Unknown;
}
}
catch (...)
Expand Down
2 changes: 1 addition & 1 deletion src/cascadia/QueryExtension/Resources/en-US/Resources.resw
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@
<comment>The message presented to the user when they attempt to use the AI chat feature without providing an AI endpoint and key.</comment>
</data>
<data name="UnknownErrorMessage" xml:space="preserve">
<value>An error occurred. Your AI provider might not be correctly configured, or the service might be temporarily unavailable.</value>
<value>An error occurred. The service might be temporarily unavailable or there might be network connectivity issues.</value>
<comment>The error message presented to the user when we were unable to query the provided endpoint.</comment>
</data>
<data name="InvalidModelMessage" xml:space="preserve">
Expand Down
2 changes: 2 additions & 0 deletions src/cascadia/QueryExtension/pch.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ TRACELOGGING_DECLARE_PROVIDER(g_hQueryExtensionProvider);

#include <winrt/Windows.Data.Json.h>

#include <chrono>

// Manually include til after we include Windows.Foundation to give it winrt superpowers
#include "til.h"

Expand Down
Loading