diff --git a/Cargo.toml b/Cargo.toml index 74e998711..a48dc7637 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,7 @@ async-trait = "0.1.36" femme = { version = "2.0.1", optional = true } futures-util = "0.3.5" http-client = { version = "6.0.0", default-features = false } -http-types = "2.2.1" +http-types = "2.5.0" kv-log-macro = "1.0.4" log = { version = "0.4.8", features = ["std"] } pin-project-lite = "0.1.7" diff --git a/examples/fib.rs b/examples/fib.rs index 39e122b2b..2b0451ecb 100644 --- a/examples/fib.rs +++ b/examples/fib.rs @@ -10,7 +10,7 @@ fn fib(n: usize) -> usize { async fn fibsum(req: Request<()>) -> tide::Result { use std::time::Instant; - let n: usize = req.param("n").unwrap_or(0); + let n: usize = req.param("n")?.parse().unwrap_or(0); // Start a stopwatch let start = Instant::now(); // Compute the nth number in the fibonacci sequence diff --git a/examples/upload.rs b/examples/upload.rs index 5ac0e1d06..05b16765d 100644 --- a/examples/upload.rs +++ b/examples/upload.rs @@ -36,7 +36,7 @@ async fn main() -> Result<(), IoError> { app.at(":file") .put(|req: Request| async move { - let path: String = req.param("file")?; + let path = req.param("file")?; let fs_path = req.state().path().join(path); let file = OpenOptions::new() @@ -55,7 +55,7 @@ async fn main() -> Result<(), IoError> { Ok(json!({ "bytes": bytes_written })) }) .get(|req: Request| async move { - let path: String = req.param("file")?; + let path = req.param("file")?; let fs_path = req.state().path().join(path); if let Ok(body) = Body::from_file(fs_path).await { diff --git a/src/lib.rs b/src/lib.rs index 64e955664..f47041964 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -90,7 +90,7 @@ pub mod sessions; pub use endpoint::Endpoint; pub use middleware::{Middleware, Next}; pub use redirect::Redirect; -pub use request::{ParamError, Request}; +pub use request::Request; pub use response::Response; pub use response_builder::ResponseBuilder; pub use route::Route; diff --git a/src/request.rs b/src/request.rs index be9b566bf..458816ddc 100644 --- a/src/request.rs +++ b/src/request.rs @@ -4,10 +4,10 @@ use route_recognizer::Params; use std::ops::Index; use std::pin::Pin; -use std::{fmt, str::FromStr}; use crate::cookies::CookieData; use crate::http::cookies::Cookie; +use crate::http::format_err; use crate::http::headers::{self, HeaderName, HeaderValues, ToHeaderValues}; use crate::http::{self, Body, Method, Mime, StatusCode, Url, Version}; use crate::Response; @@ -29,23 +29,6 @@ pin_project_lite::pin_project! { } } -#[derive(Debug)] -pub enum ParamError { - NotFound(String), - ParsingError(E), -} - -impl fmt::Display for ParamError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - ParamError::NotFound(name) => write!(f, "Param \"{}\" not found!", name), - ParamError::ParsingError(err) => write!(f, "Param failed to parse: {}", err), - } - } -} - -impl std::error::Error for ParamError {} - impl Request { /// Create a new `Request`. pub(crate) fn new(state: State, req: http_types::Request, route_params: Vec) -> Self { @@ -287,10 +270,7 @@ impl Request { /// /// # Errors /// - /// Yields a `ParamError::ParsingError` if the parameter was found but failed to parse as an - /// instance of type `T`. - /// - /// Yields a `ParamError::NotFound` if `key` is not a parameter for the route. + /// An error is returned if `key` is not a valid parameter for the route. /// /// # Examples /// @@ -301,7 +281,7 @@ impl Request { /// use tide::{Request, Result}; /// /// async fn greet(req: Request<()>) -> Result { - /// let name = req.param("name").unwrap_or("world".to_owned()); + /// let name = req.param("name").unwrap_or("world"); /// Ok(format!("Hello, {}!", name)) /// } /// @@ -312,13 +292,12 @@ impl Request { /// # /// # Ok(()) })} /// ``` - pub fn param(&self, key: &str) -> Result> { + pub fn param(&self, key: &str) -> crate::Result<&str> { self.route_params .iter() .rev() .find_map(|params| params.find(key)) - .ok_or_else(|| ParamError::NotFound(key.to_string())) - .and_then(|param| param.parse().map_err(ParamError::ParsingError)) + .ok_or_else(|| format_err!("Param \"{}\" not found", key.to_string())) } /// Parse the URL query component into a struct, using [serde_qs](https://docs.rs/serde_qs). To diff --git a/tests/params.rs b/tests/params.rs index 6dea9c632..0d551b515 100644 --- a/tests/params.rs +++ b/tests/params.rs @@ -1,69 +1,39 @@ use http_types::{self, Method, Url}; -use tide::{self, Request, Response, Result, StatusCode}; +use tide::{self, Request, Response, Result}; #[async_std::test] -async fn test_param_invalid_type() { - async fn get_by_id(req: Request<()>) -> Result { - assert_eq!( - req.param::("id").unwrap_err().to_string(), - "Param failed to parse: invalid digit found in string" - ); - let _ = req.param::("id")?; - Result::Ok(Response::new(StatusCode::Ok)) - } - let mut server = tide::new(); - server.at("/by_id/:id").get(get_by_id); - - let req = http_types::Request::new( - Method::Get, - Url::parse("http://example.com/by_id/wrong").unwrap(), - ); - let res: http_types::Response = server.respond(req).await.unwrap(); - assert_eq!(res.status(), StatusCode::InternalServerError); -} - -#[async_std::test] -async fn test_missing_param() { +async fn test_missing_param() -> tide::Result<()> { async fn greet(req: Request<()>) -> Result { - assert_eq!( - req.param::("name").unwrap_err().to_string(), - "Param \"name\" not found!" - ); - let _: String = req.param("name")?; - Result::Ok(Response::new(StatusCode::Ok)) + assert_eq!(req.param("name")?, "Param \"name\" not found"); + Ok(Response::new(200)) } + let mut server = tide::new(); server.at("/").get(greet); - let req = http_types::Request::new(Method::Get, Url::parse("http://example.com/").unwrap()); - let res: http_types::Response = server.respond(req).await.unwrap(); - assert_eq!(res.status(), StatusCode::InternalServerError); + let req = http_types::Request::new(Method::Get, Url::parse("http://example.com/")?); + let res: http_types::Response = server.respond(req).await?; + assert_eq!(res.status(), 500); + Ok(()) } #[async_std::test] -async fn hello_world_parametrized() { - async fn greet(req: tide::Request<()>) -> Result { - let name = req.param("name").unwrap_or_else(|_| "nori".to_owned()); - let mut response = tide::Response::new(StatusCode::Ok); - response.set_body(format!("{} says hello", name)); - Ok(response) +async fn hello_world_parametrized() -> Result<()> { + async fn greet(req: tide::Request<()>) -> Result> { + let body = format!("{} says hello", req.param("name").unwrap_or("nori")); + Ok(Response::builder(200).body(body)) } let mut server = tide::new(); server.at("/").get(greet); server.at("/:name").get(greet); - let req = http_types::Request::new(Method::Get, Url::parse("http://example.com/").unwrap()); - let mut res: http_types::Response = server.respond(req).await.unwrap(); - assert_eq!( - res.body_string().await.unwrap(), - "nori says hello".to_string() - ); + let req = http_types::Request::new(Method::Get, Url::parse("http://example.com/")?); + let mut res: http_types::Response = server.respond(req).await?; + assert_eq!(res.body_string().await?, "nori says hello"); - let req = http_types::Request::new(Method::Get, Url::parse("http://example.com/iron").unwrap()); - let mut res: http_types::Response = server.respond(req).await.unwrap(); - assert_eq!( - res.body_string().await.unwrap(), - "iron says hello".to_string() - ); + let req = http_types::Request::new(Method::Get, Url::parse("http://example.com/iron")?); + let mut res: http_types::Response = server.respond(req).await?; + assert_eq!(res.body_string().await?, "iron says hello"); + Ok(()) } diff --git a/tests/wildcard.rs b/tests/wildcard.rs index c17806440..6cca1c0db 100644 --- a/tests/wildcard.rs +++ b/tests/wildcard.rs @@ -1,27 +1,34 @@ mod test_utils; use test_utils::ServerTestingExt; -use tide::{Request, StatusCode}; +use tide::{Error, Request, StatusCode}; + async fn add_one(req: Request<()>) -> Result { - match req.param::("num") { - Ok(num) => Ok((num + 1).to_string()), - Err(err) => Err(tide::Error::new(StatusCode::BadRequest, err)), - } + let num: i64 = req + .param("num")? + .parse() + .map_err(|err| Error::new(StatusCode::BadRequest, err))?; + Ok((num + 1).to_string()) } async fn add_two(req: Request<()>) -> Result { - let one = req - .param::("one") - .map_err(|err| tide::Error::new(StatusCode::BadRequest, err))?; - let two = req - .param::("two") - .map_err(|err| tide::Error::new(StatusCode::BadRequest, err))?; + let one: i64 = req + .param("one")? + .parse() + .map_err(|err| Error::new(StatusCode::BadRequest, err))?; + let two: i64 = req + .param("two")? + .parse() + .map_err(|err| Error::new(StatusCode::BadRequest, err))?; Ok((one + two).to_string()) } async fn echo_path(req: Request<()>) -> Result { - match req.param::("path") { - Ok(path) => Ok(path), - Err(err) => Err(tide::Error::new(StatusCode::BadRequest, err)), + match req.param("path") { + Ok(path) => Ok(path.into()), + Err(mut err) => { + err.set_status(StatusCode::BadRequest); + Err(err) + } } } @@ -124,7 +131,7 @@ async fn nameless_internal_wildcard() -> tide::Result<()> { async fn nameless_internal_wildcard2() -> tide::Result<()> { let mut app = tide::new(); app.at("/echo/:/:path").get(|req: Request<()>| async move { - assert_eq!(req.param::("path")?, "two"); + assert_eq!(req.param("path")?, "two"); Ok("") });