Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: avoid overhead when using syn custom protocol #1457

Draft
wants to merge 3 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,10 @@ pub struct WebViewAttributes<'a> {
/// locate your files in those directories. For more information, see [Loading in-app content](https://developer.android.com/guide/webapps/load-local-content) page.
/// - iOS: To get the path of your assets, you can call [`CFBundle::resources_path`](https://docs.rs/core-foundation/latest/core_foundation/bundle/struct.CFBundle.html#method.resources_path). So url like `wry://assets/index.html` could get the html file in assets directory.
pub custom_protocols:
HashMap<String, Box<dyn Fn(WebViewId, Request<Vec<u8>>) -> Response<Cow<'static, [u8]>>>>,

/// Same as [`Self::custom_protocols`] but with an asynchronous responder.
pub async_custom_protocols:
HashMap<String, Box<dyn Fn(WebViewId, Request<Vec<u8>>, RequestAsyncResponder)>>,

/// The IPC handler to receive the message from Javascript on webview
Expand Down Expand Up @@ -605,6 +609,7 @@ impl Default for WebViewAttributes<'_> {
html: None,
initialization_scripts: Default::default(),
custom_protocols: Default::default(),
async_custom_protocols: Default::default(),
ipc_handler: None,
drag_drop_handler: None,
navigation_handler: None,
Expand Down Expand Up @@ -848,17 +853,18 @@ impl<'a> WebViewBuilder<'a> {
context.register_custom_protocol(name.clone())?;
}

if b.attrs.custom_protocols.iter().any(|(n, _)| n == &name) {
if b
.attrs
.custom_protocols
.iter()
.map(|c| c.0)
.chain(b.attrs.async_custom_protocols.iter().map(|c| c.0))
.any(|n| n == &name)
{
return Err(Error::DuplicateCustomProtocol(name));
}

b.attrs.custom_protocols.insert(
name,
Box::new(move |id, request, responder| {
let http_response = handler(id, request);
responder.respond(http_response);
}),
);
b.attrs.custom_protocols.insert(name, Box::new(handler));

Ok(b)
})
Expand Down Expand Up @@ -900,11 +906,20 @@ impl<'a> WebViewBuilder<'a> {
context.register_custom_protocol(name.clone())?;
}

if b.attrs.custom_protocols.iter().any(|(n, _)| n == &name) {
if b
.attrs
.custom_protocols
.iter()
.map(|c| c.0)
.chain(b.attrs.async_custom_protocols.iter().map(|c| c.0))
.any(|n| n == &name)
{
return Err(Error::DuplicateCustomProtocol(name));
}

b.attrs.custom_protocols.insert(name, Box::new(handler));
b.attrs
.async_custom_protocols
.insert(name, Box::new(handler));

Ok(b)
})
Expand Down
168 changes: 100 additions & 68 deletions src/webview2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ impl InnerWebView {
.iter()
.map(|n| n.0.clone())
.collect();
if !attributes.custom_protocols.is_empty() {
if !attributes.custom_protocols.is_empty() || !attributes.async_custom_protocols.is_empty() {
unsafe {
Self::attach_custom_protocol_handler(
&webview,
Expand Down Expand Up @@ -846,89 +846,121 @@ impl InnerWebView {
webview.AddWebResourceRequestedFilter(&filter, COREWEBVIEW2_WEB_RESOURCE_CONTEXT_ALL)?;
}

for name in attributes.async_custom_protocols.keys() {
// WebView2 supports non-standard protocols only on Windows 10+, so we have to use this workaround
// See https://github.com/MicrosoftEdge/WebView2Feedback/issues/73
let filter = HSTRING::from(format!("{scheme}://{name}.*"));
webview.AddWebResourceRequestedFilter(&filter, COREWEBVIEW2_WEB_RESOURCE_CONTEXT_ALL)?;
}

let env = env.clone();
let custom_protocols = std::mem::take(&mut attributes.custom_protocols);
let async_custom_protocols = std::mem::take(&mut attributes.async_custom_protocols);
let main_thread_id = std::thread::current().id();

webview.add_WebResourceRequested(
&WebResourceRequestedEventHandler::create(Box::new(move |_, args| {
let Some(args) = args else {
return Ok(());
};
let handler = WebResourceRequestedEventHandler::create(Box::new(move |_, args| {
let Some(args) = args else {
return Ok(());
};

#[cfg(feature = "tracing")]
let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty)
.entered();
#[cfg(feature = "tracing")]
let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty)
.entered();

// Request uri
let webview_request = args.Request()?;
// Request
let webview_request = args.Request()?;

// Request uri
let uri = {
let mut uri = PWSTR::null();
webview_request.Uri(&mut uri)?;
take_pwstr(uri)
// Request uri
let uri = {
let mut uri = PWSTR::null();
webview_request.Uri(&mut uri)?;
take_pwstr(uri)
};
#[cfg(feature = "tracing")]
span.record("uri", &uri);

// check if normal custom protocol
if let Some((protocol, handler)) = custom_protocols
.iter()
.find(|(protocol, _)| is_custom_protocol_uri(&uri, scheme, protocol))
{
let request = match Self::prepare_request(scheme, protocol, &webview_request, &uri) {
Ok(req) => req,
Err(e) => {
let err_response = Self::prepare_web_request_err(&env, e)?;
args.SetResponse(&err_response)?;
return Ok(());
}
};
#[cfg(feature = "tracing")]
span.record("uri", &uri);

if let Some((custom_protocol, custom_protocol_handler)) = custom_protocols
.iter()
.find(|(protocol, _)| is_custom_protocol_uri(&uri, scheme, protocol))
{
let request = match Self::prepare_request(scheme, custom_protocol, &webview_request, &uri)
{
Ok(req) => req,
Err(e) => {
let err_response = Self::prepare_web_request_err(&env, e)?;
args.SetResponse(&err_response)?;
return Ok(());

let response = handler(&webview_id, request);

match Self::prepare_web_request_response(&env, &response) {
Ok(response) => {
let _ = args.SetResponse(&response);
}
Err(e) => {
if let Ok(err_response) = Self::prepare_web_request_err(&env, e) {
let _ = args.SetResponse(&err_response);
}
};
}
}
} else if let Some((protocol, async_handler)) = async_custom_protocols // then try async protocols
.iter()
.find(|(protocol, _)| is_custom_protocol_uri(&uri, scheme, protocol))
{
let request = match Self::prepare_request(scheme, protocol, &webview_request, &uri) {
Ok(req) => req,
Err(e) => {
let err_response = Self::prepare_web_request_err(&env, e)?;
args.SetResponse(&err_response)?;
return Ok(());
}
};

let env = env.clone();
let deferral = args.GetDeferral();
let env = env.clone();
let deferral = args.GetDeferral();

let async_responder = Box::new(move |sent_response| {
let handler = move || {
match Self::prepare_web_request_response(&env, &sent_response) {
Ok(response) => {
let _ = args.SetResponse(&response);
}
Err(e) => {
if let Ok(err_response) = Self::prepare_web_request_err(&env, e) {
let _ = args.SetResponse(&err_response);
}
}
let async_responder = Box::new(move |sent_response| {
let handler = move || {
match Self::prepare_web_request_response(&env, &sent_response) {
Ok(response) => {
let _ = args.SetResponse(&response);
}

if let Ok(deferral) = &deferral {
let _ = deferral.Complete();
Err(e) => {
if let Ok(err_response) = Self::prepare_web_request_err(&env, e) {
let _ = args.SetResponse(&err_response);
}
}
};
}

if std::thread::current().id() == main_thread_id {
handler();
} else {
Self::dispatch_handler(hwnd, handler);
if let Ok(deferral) = &deferral {
let _ = deferral.Complete();
}
});
};

#[cfg(feature = "tracing")]
let _span = tracing::info_span!("wry::custom_protocol::call_handler").entered();
custom_protocol_handler(
&webview_id,
request,
RequestAsyncResponder {
responder: async_responder,
},
);
}
if std::thread::current().id() == main_thread_id {
handler();
} else {
Self::dispatch_handler(hwnd, handler);
}
});

Ok(())
})),
token,
)?;
#[cfg(feature = "tracing")]
let _span = tracing::info_span!("wry::custom_protocol::call_handler").entered();
async_handler(
&webview_id,
request,
RequestAsyncResponder {
responder: async_responder,
},
);
}

Ok(())
}));

webview.add_WebResourceRequested(&handler, token)?;

Self::attach_main_thread_dispatcher(hwnd);

Expand Down Expand Up @@ -1492,7 +1524,7 @@ impl InnerWebView {
cookie.Expires(&mut expires)?;

let expires = match expires {
-1.0 | _ if is_session.as_bool() => Some(cookie::Expiration::Session),
datetime if datetime == -1.0 || is_session.as_bool() => Some(cookie::Expiration::Session),
datetime => cookie::time::OffsetDateTime::from_unix_timestamp(datetime as _)
.ok()
.map(cookie::Expiration::DateTime),
Expand Down
Loading