Skip to content

Commit

Permalink
fix(macOS): Improve custom URL scheme handler robustness (#1440)
Browse files Browse the repository at this point in the history
* Improve custom URL scheme handler robustness and error handling

* re-add comments

---------

Co-authored-by: FabianLars <[email protected]>
  • Loading branch information
UdaraJay and FabianLars authored Jan 24, 2025
1 parent 5363d9b commit 89e9a0d
Showing 1 changed file with 71 additions and 41 deletions.
112 changes: 71 additions & 41 deletions src/wkwebview/class/url_scheme_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ extern "C" fn start_task(
) {
unsafe {
#[cfg(feature = "tracing")]
let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty)
.entered();
let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty)
.entered();

let task_key = task.hash(); // hash by task object address
let task_uuid = webview.add_custom_task_key(task_key);
Expand Down Expand Up @@ -122,7 +122,6 @@ extern "C" fn start_task(
if let Some(all_headers) = all_headers {
for current_header in all_headers.allKeys().to_vec() {
let header_value = all_headers.valueForKey(current_header).unwrap();

// inject the header into the request
http_request = http_request.header(current_header.to_string(), header_value.to_string());
}
Expand All @@ -145,37 +144,52 @@ extern "C" fn start_task(
task.didFinish();
};

fn check_webview_id_valid(webview_id: &str) -> crate::Result<()> {
if !WEBVIEW_IDS.lock().unwrap().contains(webview_id) {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
Ok(())
}

/// Task may not live longer than async custom protocol handler.
///
/// There are roughly 2 ways to cause segfault:
/// 1. Task has stopped. pointer of the task not valid anymore.
/// 2. Task had stopped, but the pointer of the task has allocated to a new task.
/// Outdated custom handler may call to the new task instance and cause segfault.
fn check_task_is_valid(
webview: &WryWebView,
task_key: usize,
current_uuid: Retained<NSUUID>,
) -> crate::Result<()> {
let latest_task_uuid = webview.get_custom_task_uuid(task_key);
if let Some(latest_uuid) = latest_task_uuid {
if latest_uuid != current_uuid {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
} else {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
Ok(())
}

// send response
match http_request.body(sent_form_body) {
Ok(final_request) => {
let responder: Box<dyn FnOnce(HttpResponse<Cow<'static, [u8]>>)> =
Box::new(move |sent_response| {
fn check_webview_id_valid(webview_id: &str) -> crate::Result<()> {
if !WEBVIEW_IDS.lock().unwrap().contains(webview_id) {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
Ok(())
}
/// Task may not live longer than async custom protocol handler.
///
/// There are roughly 2 ways to cause segfault:
/// 1. Task has stopped. pointer of the task not valid anymore.
/// 2. Task had stopped, but the pointer of the task has allocated to a new task.
/// Outdated custom handler may call to the new task instance and cause segfault.
fn check_task_is_valid(
webview: &WryWebView,
task_key: usize,
current_uuid: Retained<NSUUID>,
) -> crate::Result<()> {
let latest_task_uuid = webview.get_custom_task_uuid(task_key);
if let Some(latest_uuid) = latest_task_uuid {
if latest_uuid != current_uuid {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
} else {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
// Consolidate checks before calling into `did*` methods.
let validate = || -> crate::Result<()> {
check_webview_id_valid(webview_id)?;
check_task_is_valid(webview, task_key, task_uuid.clone())?;
Ok(())
};

// Perform an upfront validation
if let Err(e) = validate() {
#[cfg(feature = "tracing")]
tracing::warn!("Task invalid before sending response: {:?}", e);
return; // If invalid, return early without calling task methods.
}

unsafe fn response(
Expand All @@ -189,7 +203,9 @@ extern "C" fn start_task(
url: Retained<NSURL>,
sent_response: HttpResponse<Cow<'_, [u8]>>,
) -> crate::Result<()> {
check_task_is_valid(&*webview, task_key, task_uuid.clone())?;
// Validate
check_webview_id_valid(webview_id)?;
check_task_is_valid(webview, task_key, task_uuid.clone())?;

let content = sent_response.body();
// default: application/octet-stream, but should be provided by the client
Expand All @@ -200,7 +216,6 @@ extern "C" fn start_task(
let wanted_version = format!("{:#?}", sent_response.version());

let mut headers = NSMutableDictionary::new();

if let Some(mime) = wanted_mime {
headers.insert_id(
NSString::from_str(CONTENT_TYPE.as_str()).as_ref(),
Expand Down Expand Up @@ -232,34 +247,42 @@ extern "C" fn start_task(
)
.unwrap();

// Re-validate before calling didReceiveResponse
check_webview_id_valid(webview_id)?;
check_task_is_valid(&*webview, task_key, task_uuid.clone())?;
check_task_is_valid(webview, task_key, task_uuid.clone())?;

// Use map_err to convert Option<Retained<Exception>> to crate::Error
objc2::exception::catch(AssertUnwindSafe(|| {
task.didReceiveResponse(&response);
}))
.unwrap();
.map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;

// Send data
let bytes = content.as_ptr() as *mut c_void;
let data = NSData::alloc();
// MIGRATE NOTE: we copied the content to the NSData because content will be freed
// when out of scope but NSData will also free the content when it's done and cause doube free.
let data = NSData::initWithBytes_length(data, bytes, content.len());
let data = NSData::initWithBytes_length(
data,
content.as_ptr() as *mut c_void,
content.len(),
);

// Check validity again
check_webview_id_valid(webview_id)?;
check_task_is_valid(&*webview, task_key, task_uuid.clone())?;
check_task_is_valid(webview, task_key, task_uuid.clone())?;

objc2::exception::catch(AssertUnwindSafe(|| {
task.didReceiveData(&data);
}))
.unwrap();
.map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;

// Finish
check_webview_id_valid(webview_id)?;
check_task_is_valid(&*webview, task_key, task_uuid.clone())?;
check_task_is_valid(webview, task_key, task_uuid.clone())?;

objc2::exception::catch(AssertUnwindSafe(|| {
task.didFinish();
}))
.unwrap();
.map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;

{
let ids = WEBVIEW_IDS.lock().unwrap();
Expand All @@ -272,15 +295,21 @@ extern "C" fn start_task(
}
}

let _ = response(
#[cfg(feature = "tracing")]
let _span = tracing::info_span!("wry::custom_protocol::call_handler").entered();

if let Err(e) = response(
task,
webview,
task_key,
task_uuid,
webview_id,
url.clone(),
sent_response,
);
) {
#[cfg(feature = "tracing")]
tracing::error!("Error responding to task: {:?}", e);
}
});

#[cfg(feature = "tracing")]
Expand All @@ -301,6 +330,7 @@ extern "C" fn start_task(
}
}
}

extern "C" fn stop_task(
_this: &ProtocolObject<dyn WKURLSchemeHandler>,
_sel: objc2::runtime::Sel,
Expand Down

0 comments on commit 89e9a0d

Please sign in to comment.