Skip to content

Commit

Permalink
more flexible async implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
romancardenas committed Dec 19, 2023
1 parent 552063c commit 5b012c0
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 121 deletions.
1 change: 1 addition & 0 deletions riscv-peripheral/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ riscv-pac = { path = "../riscv-pac", version = "0.1.0" }
aclint-hal-async = ["embedded-hal-async"]

[package.metadata.docs.rs]
all-features = true
default-target = "riscv64imac-unknown-none-elf"
targets = [
"riscv32i-unknown-none-elf", "riscv32imc-unknown-none-elf", "riscv32imac-unknown-none-elf",
Expand Down
231 changes: 126 additions & 105 deletions riscv-peripheral/src/hal_async/aclint.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,22 @@
//! Asynchronous delay implementation for the (A)CLINT peripheral.
//!
//! # Note
//!
//! The asynchronous delay implementation for the (A)CLINT peripheral relies on the machine-level timer interrupts.
//! Therefore, it needs to schedule the machine-level timer interrupts via the [`MTIMECMP`] register assigned to the current HART.
//! Thus, the [`Delay`] instance must be created on the same HART that is used to call the asynchronous delay methods.
//!
//! # Requirements
//!
//! The following `extern "Rust"` functions must be implemented:
//!
//! - `fn _riscv_peripheral_aclint_mtimer(hart_id: usize) -> MTIMER`: This function returns the `MTIMER` register for the given HART ID.
//! - `fn _riscv_peripheral_aclint_push_timer(t: Timer) -> Result<(), Timer>`: This function pushes a new timer to a timer queue assigned to the given HART ID.
//! If it fails (e.g., the timer queue is full), it returns back the timer that failed to be pushed.
//! The logic of timer queues are application-specific and are not provided by this crate.
//! - `fn _riscv_peripheral_aclint_wake_timers(hart_id: usize, current_tick: u64) -> Option<u64>`:
//! This function pops all the expired timers from a timer queue assigned to the given HART ID and wakes their associated wakers.
//! The function returns the next [`MTIME`] tick at which the next timer expires. If the queue is empty, it returns `None`.
use crate::aclint::mtimer::{MTIME, MTIMECMP, MTIMER};
pub use crate::hal_async::delay::DelayNs;
Expand All @@ -21,101 +39,76 @@ extern "Rust" {
/// Tries to push a new timer to the timer queue assigned to the given HART ID.
/// If it fails (e.g., the timer queue is full), it returns back the timer that failed to be pushed.
///
/// # Note
///
/// the [`Delay`] reference allows to access the `MTIME` and `MTIMECMP` registers,
/// as well as handy information such as the HART ID or the clock frequency of the `MTIMER` peripheral.
///
/// # Safety
///
/// Do not call this function directly. It is only meant to be called by [`DelayAsync`].
fn _riscv_peripheral_push_timer(hart_id: usize, delay: &Delay, t: Timer) -> Result<(), Timer>;
fn _riscv_peripheral_aclint_push_timer(t: Timer) -> Result<(), Timer>;

/// Pops a expired timer from the timer queue assigned to the given HART ID.
/// If the queue is empty, it returns `Err(None)`.
/// Pops all the expired timers from the timer queue assigned to the given HART ID and wakes their associated wakers.
/// Once it is done, if the queue is empty, it returns `None`.
/// Alternatively, if the queue is not empty but the earliest timer has not expired yet,
/// it returns `Err(Some(next_expires))` where `next_expires` is the tick at which this timer expires.
/// it returns `Some(next_expires)` where `next_expires` is the tick at which this timer expires.
///
/// # Safety
///
/// It is extremely important that this function only returns a timer that has expired.
/// Otherwise, the timer will be lost and the waker will never be called.
///
/// Do not call this function directly. It is only meant to be called by [`MachineExternal`] and [`DelayAsync`].
fn _riscv_peripheral_pop_timer(hart_id: usize, current_tick: u64)
-> Result<Timer, Option<u64>>;
fn _riscv_peripheral_aclint_wake_timers(hart_id: usize, current_tick: u64) -> Option<u64>;
}

/// Machine-level timer interrupt handler.
/// This handler is triggered whenever the `MTIME` register reaches the value of the `MTIMECMP` register.
/// Machine-level timer interrupt handler. This handler is triggered whenever the `MTIME`
/// register reaches the value of the `MTIMECMP` register of the current HART.
#[no_mangle]
#[allow(non_snake_case)]
fn MachineExternal() {
// recover the MTIME and MTIMECMP registers for the current HART
let hart_id = riscv::register::mhartid::read();
let mtimer = unsafe { _riscv_peripheral_aclint_mtimer(hart_id) };
let (mtime, mtimercmp) = (mtimer.mtime, mtimer.mtimecmp_mhartid());
schedule_machine_external(hart_id, mtime, mtimercmp);
// schedule the next machine timer interrupt
schedule_machine_timer(hart_id, mtime, mtimercmp);
}

fn schedule_machine_external(hart_id: usize, mtime: MTIME, mtimercmp: MTIMECMP) {
/// Schedules the next machine timer interrupt for the given HART ID according to the timer queue.
fn schedule_machine_timer(hart_id: usize, mtime: MTIME, mtimercmp: MTIMECMP) {
unsafe { riscv::register::mie::clear_mtimer() }; // disable machine timer interrupts to avoid reentrancy
loop {
let current_tick = mtime.read();
let timer = unsafe { _riscv_peripheral_pop_timer(hart_id, current_tick) };
match timer {
Ok(timer) => {
debug_assert!(timer.expires() <= current_tick);
timer.wake();
}
Err(e) => {
if let Some(next_expires) = e {
debug_assert!(next_expires > current_tick);
mtimercmp.write(next_expires); // schedule next interrupt at next_expires
unsafe { riscv::register::mie::set_mtimer() }; // enable machine timer interrupts again
} else {
mtimercmp.write(u64::MAX); // write max to clear and "disable" the interrupt
}
break;
}
}
let current_tick = mtime.read();
if let Some(next_expires) =
unsafe { _riscv_peripheral_aclint_wake_timers(hart_id, current_tick) }
{
debug_assert!(next_expires > current_tick);
mtimercmp.write(next_expires); // schedule next interrupt at next_expires
unsafe { riscv::register::mie::set_mtimer() }; // enable machine timer interrupts again if necessary
}
}

/// Asynchronous delay implementation for (A)CLINT peripherals.
///
/// # Note
///
/// The asynchronous delay implementation for (A)CLINT peripherals relies on the machine-level timer interrupts.
/// Therefore, it needs to schedule the machine-level timer interrupts via the [`MTIMECMP`] register assigned to the current HART.
/// Thus, the [`Delay`] instance must be created on the same HART that is used to call the asynchronous delay methods.
/// Additionally, the rest of the application must not modify the [`MTIMER`] register assigned to the current HART.
#[derive(Clone)]
pub struct Delay {
mtime: MTIME,
hart_id: usize,
mtimecmp: MTIMECMP,
freq: usize,
mtime: MTIME,
mtimecmp: MTIMECMP,
}

impl Delay {
/// Creates a new `Delay` instance.
#[inline]
pub fn new<H: riscv_pac::HartIdNumber>(mtimer: MTIMER, hart_id: H, freq: usize) -> Self {
Self {
mtime: mtimer.mtime,
hart_id: hart_id.number() as _,
mtimecmp: mtimer.mtimecmp(hart_id),
freq,
}
}

/// Creates a new `Delay` instance for the current HART.
/// This function determines the current HART ID by reading the [`riscv::register::mhartid`] CSR.
///
/// # Note
///
/// This function can only be used in M-mode. For S-mode, use [`Delay::new_mhartid`] instead.
#[inline]
pub fn new_mhartid(mtimer: MTIMER, freq: usize) -> Self {
pub fn new(freq: usize) -> Self {
let hart_id = riscv::register::mhartid::read();
let mtimer = unsafe { _riscv_peripheral_aclint_mtimer(hart_id) };
let (mtime, mtimecmp) = (mtimer.mtime, mtimer.mtimecmp_mhartid());
Self {
mtime: mtimer.mtime,
hart_id,
mtimecmp: mtimer.mtimecmp_mhartid(),
freq,
mtime,
mtimecmp,
}
}

Expand All @@ -130,38 +123,84 @@ impl Delay {
pub fn set_freq(&mut self, freq: usize) {
self.freq = freq;
}
}

/// Returns the `MTIME` register.
impl DelayNs for Delay {
#[inline]
pub const fn get_mtime(&self) -> MTIME {
self.mtime
async fn delay_ns(&mut self, ns: u32) {
let n_ticks = ns as u64 * self.get_freq() as u64 / 1_000_000_000;
DelayAsync::new(self, n_ticks).await;
}

/// Returns the `MTIMECMP` register.
#[inline]
pub const fn get_mtimecmp(&self) -> MTIMECMP {
self.mtimecmp
async fn delay_us(&mut self, us: u32) {
let n_ticks = us as u64 * self.get_freq() as u64 / 1_000_000;
DelayAsync::new(self, n_ticks).await;
}

/// Returns the hart ID.
#[inline]
pub const fn get_hart_id(&self) -> usize {
self.hart_id
async fn delay_ms(&mut self, ms: u32) {
let n_ticks = ms as u64 * self.get_freq() as u64 / 1_000;
DelayAsync::new(self, n_ticks).await;
}
}

/// Timer queue entry.
/// When pushed to the timer queue via the `_riscv_peripheral_aclint_push_timer` function,
/// this entry provides the necessary information to adapt it to the timer queue implementation.
#[derive(Debug)]
pub struct Timer {
hart_id: usize,
freq: usize,
mtime: MTIME,
mtimecmp: MTIMECMP,
expires: u64,
waker: Waker,
}

impl Timer {
/// Creates a new timer queue entry.
#[inline]
pub fn new(expires: u64, waker: Waker) -> Self {
Self { expires, waker }
const fn new(
hart_id: usize,
freq: usize,
mtime: MTIME,
mtimecmp: MTIMECMP,
expires: u64,
waker: Waker,
) -> Self {
Self {
hart_id,
freq,
mtime,
mtimecmp,
expires,
waker,
}
}

/// Returns the HART ID associated with this timer.
#[inline]
pub const fn hart_id(&self) -> usize {
self.hart_id
}

/// Returns the frequency of the [`MTIME`] register associated with this timer.
#[inline]
pub const fn freq(&self) -> usize {
self.freq
}

/// Returns the [`MTIME`] register associated with this timer.
#[inline]
pub const fn mtime(&self) -> MTIME {
self.mtime
}

/// Returns the [`MTIMECMP`] register associated with this timer.
#[inline]
pub const fn mtimecmp(&self) -> MTIMECMP {
self.mtimecmp
}

/// Returns the tick at which the timer expires.
Expand All @@ -170,16 +209,16 @@ impl Timer {
self.expires
}

/// Wakes the waker associated with this timer.
/// Returns the waker associated with this timer.
#[inline]
pub fn wake(&self) {
self.waker.wake_by_ref();
pub fn waker(&self) -> Waker {
self.waker.clone()
}
}

impl PartialEq for Timer {
fn eq(&self, other: &Self) -> bool {
self.expires == other.expires
self.hart_id == other.hart_id && self.freq == other.freq && self.expires == other.expires
}
}

Expand All @@ -197,14 +236,14 @@ impl PartialOrd for Timer {
}
}

struct DelayAsync {
delay: Delay,
struct DelayAsync<'a> {
delay: &'a Delay,
expires: u64,
pushed: bool,
}

impl DelayAsync {
pub fn new(delay: Delay, n_ticks: u64) -> Self {
impl<'a> DelayAsync<'a> {
pub fn new(delay: &'a Delay, n_ticks: u64) -> Self {
let t0 = delay.mtime.read();
let expires = t0.wrapping_add(n_ticks);
Self {
Expand All @@ -215,7 +254,7 @@ impl DelayAsync {
}
}

impl Future for DelayAsync {
impl<'a> Future for DelayAsync<'a> {
type Output = ();

#[inline]
Expand All @@ -224,41 +263,23 @@ impl Future for DelayAsync {
if !self.pushed {
// we only push the timer to the queue the first time we poll
self.pushed = true;
let timer = Timer::new(self.expires, cx.waker().clone());
unsafe {
_riscv_peripheral_push_timer(self.delay.hart_id, &self.delay, timer)
.expect("timer queue is full");
};
// we also need to schedule the interrupt if the timer we just pushed is the earliest one
schedule_machine_external(
let timer = Timer::new(
self.delay.hart_id,
self.delay.freq,
self.delay.mtime,
self.delay.mtimecmp,
self.expires,
cx.waker().clone(),
);
unsafe {
_riscv_peripheral_aclint_push_timer(timer).expect("timer queue is full");
};
// we also need to reschedule the machine timer interrupt
schedule_machine_timer(self.delay.hart_id, self.delay.mtime, self.delay.mtimecmp);
}
Poll::Pending
} else {
Poll::Ready(())
}
}
}

impl DelayNs for Delay {
#[inline]
async fn delay_ns(&mut self, ns: u32) {
let n_ticks = ns as u64 * self.get_freq() as u64 / 1_000_000_000;
DelayAsync::new(self.clone(), n_ticks).await;
}

#[inline]
async fn delay_us(&mut self, us: u32) {
let n_ticks = us as u64 * self.get_freq() as u64 / 1_000_000;
DelayAsync::new(self.clone(), n_ticks).await;
}

#[inline]
async fn delay_ms(&mut self, ms: u32) {
let n_ticks = ms as u64 * self.get_freq() as u64 / 1_000;
DelayAsync::new(self.clone(), n_ticks).await;
}
}
7 changes: 6 additions & 1 deletion riscv-peripheral/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
//! Standard RISC-V peripherals for embedded systems written in Rust
//! Standard RISC-V peripherals for embedded systems written in Rust.
//!
//! ## Features
//!
//! - `aclint-hal-async`: enables the [`hal_async::delay::DelayNs`] implementation for the ACLINT peripheral.
//! This feature relies on external functions that must be provided by the user. See [`hal_async::aclint`] for more information.
#![deny(missing_docs)]
#![no_std]
Expand Down
Loading

0 comments on commit 5b012c0

Please sign in to comment.