Skip to content

Commit

Permalink
feat: add support for preserving characters when decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
ForsakenHarmony committed Sep 19, 2024
1 parent 5505565 commit a782c3c
Showing 1 changed file with 104 additions and 13 deletions.
117 changes: 104 additions & 13 deletions percent_encoding/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ use core::{fmt, mem, ops, slice, str};
/// /// https://url.spec.whatwg.org/#fragment-percent-encode-set
/// const FRAGMENT: &AsciiSet = &CONTROLS.add(b' ').add(b'"').add(b'<').add(b'>').add(b'`');
/// ```
#[derive(Debug, PartialEq, Eq)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct AsciiSet {
mask: [Chunk; ASCII_RANGE_LEN / BITS_PER_CHUNK],
}
Expand All @@ -79,7 +79,7 @@ const BITS_PER_CHUNK: usize = 8 * mem::size_of::<Chunk>();

impl AsciiSet {
/// An empty set.
pub const EMPTY: AsciiSet = AsciiSet {
pub const EMPTY: &'static AsciiSet = &AsciiSet {
mask: [0; ASCII_RANGE_LEN / BITS_PER_CHUNK],
};

Expand Down Expand Up @@ -108,7 +108,7 @@ impl AsciiSet {
}

/// Return the union of two sets.
pub const fn union(&self, other: Self) -> Self {
pub const fn union(&self, other: &Self) -> Self {
let mask = [
self.mask[0] | other.mask[0],
self.mask[1] | other.mask[1],
Expand All @@ -128,15 +128,31 @@ impl AsciiSet {
impl ops::Add for AsciiSet {
type Output = Self;

fn add(self, other: Self) -> Self {
fn add(self, other: Self) -> Self::Output {
self.union(&other)
}
}

impl ops::Add for &AsciiSet {
type Output = AsciiSet;

fn add(self, other: Self) -> Self::Output {
self.union(other)
}
}

impl ops::Not for AsciiSet {
type Output = Self;

fn not(self) -> Self {
fn not(self) -> Self::Output {
self.complement()
}
}

impl ops::Not for &AsciiSet {
type Output = AsciiSet;

fn not(self) -> Self::Output {
self.complement()
}
}
Expand Down Expand Up @@ -268,7 +284,7 @@ pub fn percent_encode_byte(byte: u8) -> &'static str {
/// assert_eq!(percent_encode(b"foo bar?", NON_ALPHANUMERIC).to_string(), "foo%20bar%3F");
/// ```
#[inline]
pub fn percent_encode<'a>(input: &'a [u8], ascii_set: &'static AsciiSet) -> PercentEncode<'a> {
pub fn percent_encode<'a>(input: &'a [u8], ascii_set: &'a AsciiSet) -> PercentEncode<'a> {
PercentEncode {
bytes: input,
ascii_set,
Expand All @@ -287,15 +303,15 @@ pub fn percent_encode<'a>(input: &'a [u8], ascii_set: &'static AsciiSet) -> Perc
/// assert_eq!(utf8_percent_encode("foo bar?", NON_ALPHANUMERIC).to_string(), "foo%20bar%3F");
/// ```
#[inline]
pub fn utf8_percent_encode<'a>(input: &'a str, ascii_set: &'static AsciiSet) -> PercentEncode<'a> {
pub fn utf8_percent_encode<'a>(input: &'a str, ascii_set: &'a AsciiSet) -> PercentEncode<'a> {
percent_encode(input.as_bytes(), ascii_set)
}

/// The return type of [`percent_encode`] and [`utf8_percent_encode`].
#[derive(Clone)]
pub struct PercentEncode<'a> {
bytes: &'a [u8],
ascii_set: &'static AsciiSet,
ascii_set: &'a AsciiSet,
}

impl<'a> Iterator for PercentEncode<'a> {
Expand Down Expand Up @@ -372,6 +388,19 @@ pub fn percent_decode_str(input: &str) -> PercentDecode<'_> {
percent_decode(input.as_bytes())
}

/// Percent-decode the given string preserving the given ascii_set.
///
/// <https://url.spec.whatwg.org/#string-percent-decode>
///
/// See [`percent_decode`] regarding the return type.
#[inline]
pub fn percent_decode_str_with_set<'a>(
input: &'a str,
ascii_set: &'a AsciiSet,
) -> PercentDecode<'a> {
percent_decode_with_set(input.as_bytes(), ascii_set)
}

/// Percent-decode the given bytes.
///
/// <https://url.spec.whatwg.org/#percent-decode>
Expand All @@ -394,13 +423,44 @@ pub fn percent_decode_str(input: &str) -> PercentDecode<'_> {
pub fn percent_decode(input: &[u8]) -> PercentDecode<'_> {
PercentDecode {
bytes: input.iter(),
ascii_set: None,
}
}

/// Percent-decode the given bytes preserving the given ascii_set.
///
/// <https://url.spec.whatwg.org/#percent-decode>
///
/// Any sequence of `%` followed by two hexadecimal digits expect for the given [AsciiSet] is decoded.
/// The return type:
///
/// * Implements `Into<Cow<u8>>` borrowing `input` when it contains no percent-encoded sequence,
/// * Implements `Iterator<Item = u8>` and therefore has a `.collect::<Vec<u8>>()` method,
/// * Has `decode_utf8()` and `decode_utf8_lossy()` methods.
///
/// # Examples
///
/// ```
/// use percent_encoding::{percent_decode_with_set, NON_ALPHANUMERIC};
///
/// assert_eq!(percent_decode_with_set(b"%66oo%20bar%3f", &!NON_ALPHANUMERIC).decode_utf8().unwrap(), "%66oo bar?");
/// ```
#[inline]
pub fn percent_decode_with_set<'a>(
input: &'a [u8],
ascii_set: &'a AsciiSet,
) -> PercentDecode<'a> {
PercentDecode {
bytes: input.iter(),
ascii_set: Some(ascii_set),
}
}

/// The return type of [`percent_decode`].
#[derive(Clone, Debug)]
pub struct PercentDecode<'a> {
bytes: slice::Iter<'a, u8>,
ascii_set: Option<&'a AsciiSet>,
}

fn after_percent_sign(iter: &mut slice::Iter<'_, u8>) -> Option<u8> {
Expand All @@ -411,13 +471,35 @@ fn after_percent_sign(iter: &mut slice::Iter<'_, u8>) -> Option<u8> {
Some(h as u8 * 0x10 + l as u8)
}

fn after_percent_sign_lookahead<'a>(
iter: &mut slice::Iter<'a, u8>,
) -> Option<(u8, slice::Iter<'a, u8>)> {
let mut cloned_iter = iter.clone();
let h = char::from(*cloned_iter.next()?).to_digit(16)?;
let l = char::from(*cloned_iter.next()?).to_digit(16)?;
Some((h as u8 * 0x10 + l as u8, cloned_iter))
}

impl<'a> Iterator for PercentDecode<'a> {
type Item = u8;

fn next(&mut self) -> Option<u8> {
self.bytes.next().map(|&byte| {
if byte == b'%' {
after_percent_sign(&mut self.bytes).unwrap_or(byte)
if byte != b'%' {
return byte;
}

let Some((decoded_byte, iter)) = after_percent_sign_lookahead(&mut self.bytes) else {
return byte;
};

let should_decode = self
.ascii_set
.map_or(true, |ascii_set| !ascii_set.contains(decoded_byte));

if should_decode {
self.bytes = iter;
decoded_byte
} else {
byte
}
Expand Down Expand Up @@ -447,11 +529,20 @@ impl<'a> PercentDecode<'a> {
let mut bytes_iter = self.bytes.clone();
while bytes_iter.any(|&b| b == b'%') {
if let Some(decoded_byte) = after_percent_sign(&mut bytes_iter) {
if let Some(ascii_set) = self.ascii_set {
if ascii_set.contains(decoded_byte) {
continue;
}
}

let initial_bytes = self.bytes.as_slice();
let unchanged_bytes_len = initial_bytes.len() - bytes_iter.len() - 3;
let mut decoded = initial_bytes[..unchanged_bytes_len].to_owned();
decoded.push(decoded_byte);
decoded.extend(PercentDecode { bytes: bytes_iter });
decoded.extend(PercentDecode {
bytes: bytes_iter,
ascii_set: self.ascii_set,
});
return Some(decoded);
}
}
Expand Down Expand Up @@ -542,8 +633,8 @@ mod tests {
/// useful for defining sets in a modular way.
#[test]
fn union() {
const A: AsciiSet = AsciiSet::EMPTY.add(b'A');
const B: AsciiSet = AsciiSet::EMPTY.add(b'B');
const A: &AsciiSet = &AsciiSet::EMPTY.add(b'A');
const B: &AsciiSet = &AsciiSet::EMPTY.add(b'B');
const UNION: AsciiSet = A.union(B);
const EXPECTED: AsciiSet = AsciiSet::EMPTY.add(b'A').add(b'B');
assert_eq!(UNION, EXPECTED);
Expand Down

0 comments on commit a782c3c

Please sign in to comment.