use std::collections::hash_map::DefaultHasher;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::ops::Deref;
use async_trait::async_trait;
use chrono::{Duration, Local};
use log::{debug, error, trace};
use osproto::identity as protocol;
use reqwest::{Client, IntoUrl, Method, RequestBuilder, Response, Url};
use tokio::sync::RwLock;
use super::{request, AuthType, EndpointFilters, Error, ErrorKind, InterfaceType, ValidInterfaces};
pub use osproto::identity::IdOrName;
/// Error text returned when an authentication response has no
/// `X-Subject-Token` header.
const MISSING_SUBJECT_HEADER: &str = "Missing X-Subject-Token header";
/// Error text returned when the `X-Subject-Token` header cannot be read as a
/// string.
const INVALID_SUBJECT_HEADER: &str = "Invalid X-Subject-Token header";
/// Remaining validity, in minutes, at or below which a cached token is no
/// longer considered usable and a new one is requested.
const TOKEN_MIN_VALIDITY: i64 = 10;
/// Authorization scope requested when obtaining a token.
#[derive(Debug)]
pub enum Scope {
    /// Request a token scoped to a project.
    Project {
        /// The project, identified by ID or by name.
        project: IdOrName,
        /// Optional domain (by ID or name) sent along with the project.
        domain: Option<IdOrName>,
    },
}
/// A token obtained from the identity service.
///
/// `Debug` is implemented by hand so that the secret value is never printed.
#[derive(Clone)]
struct Token {
    /// Raw token taken from the `X-Subject-Token` response header.
    value: String,
    /// Decoded token body (expiry time, service catalog, ...).
    body: protocol::Token,
}
impl fmt::Debug for Token {
    /// Formats the token without exposing the secret: only a hash of the raw
    /// value is shown, followed by the decoded body.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let fingerprint = {
            let mut state = DefaultHasher::new();
            Hash::hash(&self.value, &mut state);
            state.finish()
        };
        write!(
            f,
            "Token {{ value: hash({}), body: {:?} }}",
            fingerprint, self.body
        )
    }
}
/// An identity that authenticates against an identity service.
pub trait Identity {
    /// The base URL of the identity service used for authentication.
    fn auth_url(&self) -> &Url;
}
/// Password-based authentication against an Identity v3 service.
///
/// Holds the HTTP client, the normalized auth URL, the prepared
/// authentication request body and a token cache that is reused by every
/// request made through this value.
///
/// NOTE(review): the derived `Debug` prints `body`, which contains the
/// plaintext password unless `protocol::AuthRoot`'s `Debug` redacts it —
/// confirm before logging this type.
#[derive(Debug)]
pub struct Password {
    /// HTTP client used both for authentication and for authenticated requests.
    client: Client,
    /// Auth URL with at most one trailing slash removed by the constructor.
    auth_url: Url,
    /// Request body posted to the token endpoint.
    body: protocol::AuthRoot,
    /// Full URL of the `auth/tokens` endpoint.
    token_endpoint: String,
    /// Most recently obtained token, if any.
    cached_token: RwLock<Option<Token>>,
    /// Default filters applied to service catalog lookups.
    filters: EndpointFilters,
}
/// Copies the configuration but not the token cache: the clone starts with no
/// token and authenticates by itself the first time it needs one.
impl Clone for Password {
    fn clone(&self) -> Password {
        let Password {
            client,
            auth_url,
            body,
            token_endpoint,
            filters,
            ..
        } = self;
        Password {
            client: client.clone(),
            auth_url: auth_url.clone(),
            body: body.clone(),
            token_endpoint: token_endpoint.clone(),
            // The cached token is not copied into the clone.
            cached_token: RwLock::new(None),
            filters: filters.clone(),
        }
    }
}
impl Identity for Password {
    /// Returns the auth URL as normalized by the constructor (one trailing
    /// slash removed).
    fn auth_url(&self) -> &Url {
        let url: &Url = &self.auth_url;
        url
    }
}
impl Password {
pub fn new<U, S1, S2, S3>(
auth_url: U,
user_name: S1,
password: S2,
user_domain_name: S3,
) -> Result<Password, Error>
where
U: IntoUrl,
S1: Into<String>,
S2: Into<String>,
S3: Into<String>,
{
Password::new_with_client(
auth_url,
Client::new(),
user_name,
password,
user_domain_name,
)
}
pub fn new_with_client<U, S1, S2, S3>(
auth_url: U,
client: Client,
user_name: S1,
password: S2,
user_domain_name: S3,
) -> Result<Password, Error>
where
U: IntoUrl,
S1: Into<String>,
S2: Into<String>,
S3: Into<String>,
{
let mut url = auth_url.into_url()?;
let _ = url
.path_segments_mut()
.map_err(|_| Error::new(ErrorKind::InvalidConfig, "Invalid auth_url: wrong schema?"))?
.pop_if_empty();
let token_endpoint = if url.as_str().ends_with("/v3") {
format!("{}/auth/tokens", url)
} else {
format!("{}/v3/auth/tokens", url)
};
let pw = protocol::UserAndPassword {
user: protocol::IdOrName::Name(user_name.into()),
password: password.into(),
domain: Some(protocol::IdOrName::Name(user_domain_name.into())),
};
let body = protocol::AuthRoot {
auth: protocol::Auth {
identity: protocol::Identity::Password(pw),
scope: None,
},
};
Ok(Password {
client,
auth_url: url,
body,
token_endpoint,
cached_token: RwLock::new(None),
filters: EndpointFilters::default(),
})
}
#[inline]
pub fn endpoint_filters(&self) -> &EndpointFilters {
&self.filters
}
#[inline]
pub fn endpoint_filters_mut(&mut self) -> &mut EndpointFilters {
&mut self.filters
}
pub fn set_default_endpoint_interface(&mut self, endpoint_interface: InterfaceType) {
self.filters.interfaces = ValidInterfaces::one(endpoint_interface);
}
#[inline]
pub fn set_endpoint_filters(&mut self, filters: EndpointFilters) {
self.filters = filters;
}
#[deprecated(since = "0.3.0", note = "Use set_filters or filters_mut")]
pub fn set_region<S>(&mut self, region: S)
where
S: Into<String>,
{
self.filters.region = Some(region.into());
}
#[inline]
pub fn set_project_scope(&mut self, project: IdOrName, domain: impl Into<Option<IdOrName>>) {
self.set_scope(Scope::Project {
project,
domain: domain.into(),
});
}
pub fn set_scope(&mut self, scope: Scope) {
self.body.auth.scope = Some(match scope {
Scope::Project { project, domain } => {
protocol::Scope::Project(protocol::Project { project, domain })
}
});
}
#[inline]
pub fn with_default_endpoint_interface(mut self, endpoint_interface: InterfaceType) -> Self {
self.set_default_endpoint_interface(endpoint_interface);
self
}
#[inline]
pub fn with_endpoint_filters(mut self, filters: EndpointFilters) -> Self {
self.filters = filters;
self
}
#[inline]
pub fn with_project_scope(
mut self,
project: IdOrName,
domain: impl Into<Option<IdOrName>>,
) -> Password {
self.set_project_scope(project, domain);
self
}
#[inline]
pub fn with_region<S>(mut self, region: S) -> Self
where
S: Into<String>,
{
self.filters.region = Some(region.into());
self
}
#[inline]
pub fn with_scope(mut self, scope: Scope) -> Self {
self.set_scope(scope);
self
}
async fn do_refresh(&self, force: bool) -> Result<(), Error> {
if !force && token_alive(&self.cached_token.read().await) {
return Ok(());
}
let mut lock = self.cached_token.write().await;
if token_alive(&lock) {
return Ok(());
}
let resp = self
.client
.post(&self.token_endpoint)
.json(&self.body)
.send()
.await?;
*lock = Some(token_from_response(request::check(resp).await?).await?);
Ok(())
}
#[inline]
pub fn user(&self) -> &IdOrName {
match self.body.auth.identity {
protocol::Identity::Password(ref pw) => &pw.user,
_ => unreachable!(),
}
}
#[inline]
pub fn project(&self) -> Option<&IdOrName> {
match self.body.auth.scope {
Some(protocol::Scope::Project(ref prj)) => Some(&prj.project),
_ => None,
}
}
#[inline]
async fn get_token(&self) -> Result<String, Error> {
self.do_refresh(false).await?;
Ok(self
.cached_token
.read()
.await
.as_ref()
.unwrap()
.value
.clone())
}
}
/// Whether a cached token exists and has more than `TOKEN_MIN_VALIDITY`
/// minutes of validity left.
///
/// Accepts any guard dereferencing to the cache, so it works with both the
/// read and the write lock guards.
#[inline]
fn token_alive(token: &impl Deref<Target = Option<Token>>) -> bool {
    match &**token {
        None => false,
        Some(current) => {
            let remaining = current
                .body
                .expires_at
                .signed_duration_since(Local::now());
            trace!("Token is valid for {:?}", remaining);
            remaining > Duration::minutes(TOKEN_MIN_VALIDITY)
        }
    }
}
#[async_trait]
impl AuthType for Password {
    /// Returns this identity's default endpoint filters.
    fn default_filters(&self) -> Option<&EndpointFilters> {
        Some(&self.filters)
    }

    /// Builds a request to `url` carrying the current token in the
    /// `x-auth-token` header, refreshing the token first if needed.
    async fn request(&self, method: Method, url: Url) -> Result<RequestBuilder, Error> {
        let token = self.get_token().await?;
        Ok(self
            .client
            .request(method, url)
            .header("x-auth-token", token))
    }

    /// Finds the URL of `service_type` in the cached token's service catalog.
    ///
    /// `filters` are passed through `EndpointFilters::with_defaults` together
    /// with this identity's own filters before the lookup. The token is
    /// refreshed first if needed.
    async fn get_endpoint(
        &self,
        service_type: String,
        filters: EndpointFilters,
    ) -> Result<Url, Error> {
        let real_filters = filters.with_defaults(&self.filters);
        debug!(
            "Requesting a catalog endpoint for service '{}', filters {:?}",
            service_type, real_filters
        );
        self.do_refresh(false).await?;
        let lock = self.cached_token.read().await;
        // A successful `do_refresh` leaves a token in the cache, and the cache
        // is never reset to `None` afterwards.
        let token = lock
            .as_ref()
            .expect("token cache is populated after a successful refresh");
        real_filters.find_in_catalog(&token.body.catalog, &service_type)
    }

    /// Requests a forced token refresh (`do_refresh(true)`).
    async fn refresh(&self) -> Result<(), Error> {
        self.do_refresh(true).await
    }
}
async fn token_from_response(resp: Response) -> Result<Token, Error> {
let value = match resp.headers().get("x-subject-token") {
Some(hdr) => match hdr.to_str() {
Ok(s) => Ok(s.to_string()),
Err(e) => {
error!(
"Invalid X-Subject-Token {:?} received from {}: {}",
hdr,
resp.url(),
e
);
Err(Error::new(
ErrorKind::InvalidResponse,
INVALID_SUBJECT_HEADER,
))
}
},
None => {
error!("No X-Subject-Token header received from {}", resp.url());
Err(Error::new(
ErrorKind::InvalidResponse,
MISSING_SUBJECT_HEADER,
))
}
}?;
let root = resp.json::<protocol::TokenRoot>().await?;
debug!("Received a token expiring at {}", root.token.expires_at);
trace!("Received catalog: {:?}", root.token.catalog);
Ok(Token {
value,
body: root.token,
})
}
#[cfg(test)]
pub mod test {
    #![allow(unused_results)]

    use super::{IdOrName, Identity, Password};

    /// Shorthand for a name-based `IdOrName`.
    fn name(value: &str) -> IdOrName {
        IdOrName::Name(value.to_string())
    }

    /// Builds the project-scoped identity shared by the token endpoint tests.
    fn scoped_identity(auth_url: &str) -> Password {
        Password::new(auth_url, "user", "pa$$w0rd", "example.com")
            .unwrap()
            .with_project_scope(name("cool project"), name("example.com"))
    }

    /// Checks everything the scoped tests expect apart from the auth URL.
    fn check_scoped_identity(id: &Password) {
        assert_eq!(id.user(), &name("user"));
        assert_eq!(id.project(), Some(&name("cool project")));
        assert_eq!(
            &id.token_endpoint,
            "http://127.0.0.1:8080/identity/v3/auth/tokens"
        );
        assert_eq!(id.endpoint_filters().region, None);
    }

    #[test]
    fn test_identity_new() {
        let id = Password::new("http://127.0.0.1:8080/", "admin", "pa$$w0rd", "Default").unwrap();
        let url = id.auth_url();
        assert_eq!(url.scheme(), "http");
        assert_eq!(url.host_str().unwrap(), "127.0.0.1");
        assert_eq!(url.port().unwrap(), 8080u16);
        assert_eq!(url.path(), "/");
        assert_eq!(id.user(), &name("admin"));
    }

    #[test]
    fn test_identity_new_invalid() {
        assert!(Password::new("http://127.0.0.1 8080/", "admin", "pa$$w0rd", "Default").is_err());
    }

    #[test]
    fn test_identity_create() {
        let id = scoped_identity("http://127.0.0.1:8080/identity");
        assert_eq!(id.auth_url().to_string(), "http://127.0.0.1:8080/identity");
        check_scoped_identity(&id);
    }

    #[test]
    fn test_token_endpoint_with_trailing_slash() {
        let id = scoped_identity("http://127.0.0.1:8080/identity/");
        assert_eq!(id.auth_url().to_string(), "http://127.0.0.1:8080/identity");
        check_scoped_identity(&id);
    }

    #[test]
    fn test_token_endpoint_with_v3() {
        let id = scoped_identity("http://127.0.0.1:8080/identity/v3");
        assert_eq!(
            id.auth_url().to_string(),
            "http://127.0.0.1:8080/identity/v3"
        );
        check_scoped_identity(&id);
    }

    #[test]
    fn test_token_endpoint_with_trailing_slash_v3() {
        let id = scoped_identity("http://127.0.0.1:8080/identity/v3/");
        assert_eq!(
            id.auth_url().to_string(),
            "http://127.0.0.1:8080/identity/v3"
        );
        check_scoped_identity(&id);
    }
}