use std::cell::RefCell;
use std::io;
use std::pin::Pin;
use std::result;
use std::task::{Context, Poll};
use bytes::Bytes;
use futures::executor::{self, BlockingStream};
use futures::stream::Stream;
use futures::Future;
use pin_project::pin_project;
use reqwest::{Body, RequestBuilder, Response};
use reqwest::{Method, Url};
use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio::runtime::{Builder as RuntimeBuilder, Runtime};
use super::request;
use super::services::ServiceType;
use super::{ApiVersion, AuthType, EndpointFilters, Error, InterfaceType, Session};
pub type Result<T> = result::Result<T, Error>;
pub type SyncStreamItem = result::Result<Bytes, ::reqwest::Error>;
/// Synchronous reader over an asynchronous stream of byte chunks.
///
/// Implements `io::Read`: chunks are pulled from `inner` on demand and the chunk being
/// consumed is kept in `current`. The stream borrows the `SyncSession` it was created from,
/// so it cannot outlive that session.
#[derive(Debug)]
pub struct SyncStream<'s, S, E = ::reqwest::Error>
where
    S: Unpin + Stream<Item = result::Result<Bytes, E>>,
{
    /// Session this stream was created from.
    session: &'s SyncSession,
    /// The wrapped stream of chunks.
    inner: BlockingStream<S>,
    /// Chunk currently being handed out by `read`.
    current: io::Cursor<Bytes>,
}
/// Adapter exposing a blocking `io::Read` as a `Stream` of byte chunks.
///
/// It can be converted into a `reqwest::Body` so that synchronous readers can be used as
/// request bodies.
#[pin_project]
#[derive(Clone, Debug, Default)]
pub struct SyncBody<R> {
    /// The wrapped blocking reader (not structurally pinned).
    reader: R,
}
/// Synchronous wrapper around an asynchronous `Session`.
///
/// Every instance owns its own Tokio runtime, which is used to block on the futures of the
/// wrapped session. The runtime sits in a `RefCell` because it is driven through
/// `borrow_mut()` while the public API mostly takes `&self`; as a consequence this type is
/// not `Sync`.
#[derive(Debug)]
pub struct SyncSession {
    /// The wrapped asynchronous session.
    inner: Session,
    /// Runtime that drives the futures produced by `inner`.
    runtime: RefCell<Runtime>,
}
impl From<SyncSession> for Session {
fn from(value: SyncSession) -> Session {
value.inner
}
}
impl From<Session> for SyncSession {
fn from(value: Session) -> SyncSession {
SyncSession::new(value)
}
}
impl Clone for SyncSession {
fn clone(&self) -> SyncSession {
SyncSession::new(self.inner.clone())
}
}
impl SyncSession {
pub fn new(session: Session) -> SyncSession {
SyncSession {
inner: session,
runtime: RefCell::new(
RuntimeBuilder::new()
.basic_scheduler()
.enable_io()
.build()
.expect("Could not create a runtime"),
),
}
}
#[inline]
pub fn from_config<S: AsRef<str>>(cloud_name: S) -> Result<SyncSession> {
Ok(Self::new(Session::from_config(cloud_name)?))
}
#[inline]
pub fn from_env() -> Result<SyncSession> {
Ok(Self::new(Session::from_env()?))
}
#[inline]
pub fn auth_type(&self) -> &dyn AuthType {
self.inner.auth_type()
}
#[inline]
pub fn endpoint_filters(&self) -> &EndpointFilters {
self.inner.endpoint_filters()
}
#[inline]
pub fn endpoint_filters_mut(&mut self) -> &mut EndpointFilters {
self.inner.endpoint_filters_mut()
}
#[inline]
pub fn refresh(&mut self) -> Result<()> {
let fut = self.inner.refresh();
self.runtime.borrow_mut().block_on(fut)
}
#[inline]
pub fn session(&self) -> &Session {
&self.inner
}
#[inline]
pub fn set_auth_type<Auth: AuthType + 'static>(&mut self, auth_type: Auth) {
self.inner.set_auth_type(auth_type);
}
pub fn set_endpoint_interface(&mut self, endpoint_interface: InterfaceType) {
self.inner.set_endpoint_interface(endpoint_interface);
}
#[inline]
pub fn with_auth_type<Auth: AuthType + 'static>(mut self, auth_method: Auth) -> SyncSession {
self.set_auth_type(auth_method);
self
}
#[inline]
pub fn with_endpoint_filters(mut self, endpoint_filters: EndpointFilters) -> SyncSession {
*self.endpoint_filters_mut() = endpoint_filters;
self
}
#[inline]
pub fn with_endpoint_interface(mut self, endpoint_interface: InterfaceType) -> SyncSession {
self.set_endpoint_interface(endpoint_interface);
self
}
#[inline]
pub fn get_api_versions<Srv>(&self, service: Srv) -> Result<Option<(ApiVersion, ApiVersion)>>
where
Srv: ServiceType + Send,
{
self.block_on(self.inner.get_api_versions(service))
}
#[inline]
pub fn get_endpoint<Srv, I>(&self, service: Srv, path: I) -> Result<Url>
where
Srv: ServiceType + Send,
I: IntoIterator,
I::Item: AsRef<str>,
I::IntoIter: Send,
{
self.block_on(self.inner.get_endpoint(service, path))
}
#[inline]
pub fn get_major_version<Srv>(&self, service: Srv) -> Result<Option<ApiVersion>>
where
Srv: ServiceType + Send,
{
self.block_on(self.inner.get_major_version(service))
}
#[inline]
pub fn pick_api_version<Srv, I>(&self, service: Srv, versions: I) -> Result<Option<ApiVersion>>
where
Srv: ServiceType + Send,
I: IntoIterator<Item = ApiVersion>,
I::IntoIter: Send,
{
self.block_on(self.inner.pick_api_version(service, versions))
}
#[inline]
pub fn supports_api_version<Srv: ServiceType + Send>(
&self,
service: Srv,
version: ApiVersion,
) -> Result<bool> {
self.block_on(self.inner.supports_api_version(service, version))
}
pub fn request<Srv, I>(
&self,
service: Srv,
method: Method,
path: I,
api_version: Option<ApiVersion>,
) -> Result<RequestBuilder>
where
Srv: ServiceType + Send + Clone,
I: IntoIterator,
I::Item: AsRef<str>,
I::IntoIter: Send,
{
self.block_on(self.inner.request(service, method, path, api_version))
}
#[inline]
pub fn get<Srv, I>(
&self,
service: Srv,
path: I,
api_version: Option<ApiVersion>,
) -> Result<Response>
where
Srv: ServiceType + Send + Clone,
I: IntoIterator,
I::Item: AsRef<str>,
I::IntoIter: Send,
{
self.send_checked(self.request(service, Method::GET, path, api_version)?)
}
#[inline]
pub fn get_json<Srv, I, T>(
&self,
service: Srv,
path: I,
api_version: Option<ApiVersion>,
) -> Result<T>
where
Srv: ServiceType + Send + Clone,
I: IntoIterator,
I::Item: AsRef<str>,
I::IntoIter: Send,
T: DeserializeOwned + Send,
{
self.fetch_json(self.request(service, Method::GET, path, api_version)?)
}
#[inline]
pub fn get_json_query<Srv, I, Q, T>(
&self,
service: Srv,
path: I,
query: Q,
api_version: Option<ApiVersion>,
) -> Result<T>
where
Srv: ServiceType + Send + Clone,
I: IntoIterator,
I::Item: AsRef<str>,
I::IntoIter: Send,
Q: Serialize + Send,
T: DeserializeOwned + Send,
{
self.fetch_json(
self.request(service, Method::GET, path, api_version)?
.query(&query),
)
}
#[inline]
pub fn get_query<Srv, I, Q>(
&self,
service: Srv,
path: I,
query: Q,
api_version: Option<ApiVersion>,
) -> Result<Response>
where
Srv: ServiceType + Send + Clone,
I: IntoIterator,
I::Item: AsRef<str>,
I::IntoIter: Send,
Q: Serialize + Send,
{
self.send_checked(
self.request(service, Method::GET, path, api_version)?
.query(&query),
)
}
#[inline]
pub fn download(&self, response: Response) -> SyncStream<impl Stream<Item = SyncStreamItem>> {
SyncStream::new(self, response.bytes_stream())
}
#[inline]
pub fn post<Srv, I, T>(
&self,
service: Srv,
path: I,
body: T,
api_version: Option<ApiVersion>,
) -> Result<Response>
where
Srv: ServiceType + Send + Clone,
I: IntoIterator,
I::Item: AsRef<str>,
I::IntoIter: Send,
T: Serialize + Send,
{
self.send_checked(
self.request(service, Method::POST, path, api_version)?
.json(&body),
)
}
#[inline]
pub fn post_json<Srv, I, T, R>(
&self,
service: Srv,
path: I,
body: T,
api_version: Option<ApiVersion>,
) -> Result<R>
where
Srv: ServiceType + Send + Clone,
I: IntoIterator,
I::Item: AsRef<str>,
I::IntoIter: Send,
T: Serialize + Send,
R: DeserializeOwned + Send,
{
self.fetch_json(
self.request(service, Method::POST, path, api_version)?
.json(&body),
)
}
#[inline]
pub fn put<Srv, I, T>(
&self,
service: Srv,
path: I,
body: T,
api_version: Option<ApiVersion>,
) -> Result<Response>
where
Srv: ServiceType + Send + Clone,
I: IntoIterator,
I::Item: AsRef<str>,
I::IntoIter: Send,
T: Serialize + Send,
{
self.send_checked(
self.request(service, Method::PUT, path, api_version)?
.json(&body),
)
}
#[inline]
pub fn put_empty<Srv, I>(
&self,
service: Srv,
path: I,
api_version: Option<ApiVersion>,
) -> Result<Response>
where
Srv: ServiceType + Send + Clone,
I: IntoIterator,
I::Item: AsRef<str>,
I::IntoIter: Send,
{
self.send_checked(self.request(service, Method::PUT, path, api_version)?)
}
#[inline]
pub fn put_json<Srv, I, T, R>(
&self,
service: Srv,
path: I,
body: T,
api_version: Option<ApiVersion>,
) -> Result<R>
where
Srv: ServiceType + Send + Clone,
I: IntoIterator,
I::Item: AsRef<str>,
I::IntoIter: Send,
T: Serialize + Send,
R: DeserializeOwned + Send,
{
self.fetch_json(
self.request(service, Method::PUT, path, api_version)?
.json(&body),
)
}
#[inline]
pub fn delete<Srv, I>(
&self,
service: Srv,
path: I,
api_version: Option<ApiVersion>,
) -> Result<Response>
where
Srv: ServiceType + Send + Clone,
I: IntoIterator,
I::Item: AsRef<str>,
I::IntoIter: Send,
{
self.send_checked(self.request(service, Method::DELETE, path, api_version)?)
}
#[inline]
pub fn fetch_json<T>(&self, builder: RequestBuilder) -> Result<T>
where
T: DeserializeOwned + Send,
{
self.block_on(async { request::to_json(builder.send().await?).await })
}
#[inline]
pub fn send_checked(&self, builder: RequestBuilder) -> Result<Response> {
self.block_on(async { request::check(builder.send().await?).await })
}
#[inline]
fn block_on<F>(&self, f: F) -> F::Output
where
F: Future,
{
self.runtime.borrow_mut().block_on(f)
}
}
impl<'s, S, E> SyncStream<'s, S, E>
where
S: Stream<Item = result::Result<Bytes, E>> + Unpin,
{
fn new(session: &'s SyncSession, inner: S) -> SyncStream<S, E> {
SyncStream {
session,
inner: executor::block_on_stream(inner),
current: io::Cursor::default(),
}
}
}
impl<'s, S, E> io::Read for SyncStream<'s, S, E>
where
S: Stream<Item = result::Result<Bytes, E>> + Unpin,
E: Into<Box<dyn ::std::error::Error + Send + Sync + 'static>>,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
loop {
let existing = self.current.read(buf)?;
if existing > 0 {
return Ok(existing);
}
if let Some(next) = self.inner.next() {
self.current =
io::Cursor::new(next.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?);
} else {
return Ok(0);
}
}
}
}
impl<R> SyncBody<R> {
#[inline]
pub fn new(body: R) -> SyncBody<R> {
SyncBody { reader: body }
}
}
impl<R> Stream for SyncBody<R>
where
    R: io::Read,
{
    type Item = ::std::result::Result<Bytes, io::Error>;

    /// Read the next chunk (at most 16 KiB) from the wrapped reader.
    ///
    /// The read is *blocking* and happens inside `poll_next`, so this never returns
    /// `Poll::Pending`. Yields `None` once the reader returns 0 bytes, and `Some(Err(_))` on
    /// I/O errors other than `ErrorKind::Interrupted`, which are retried.
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Option<Self::Item>> {
        const CHUNK_SIZE: usize = 16384;

        let mut buffer = vec![0; CHUNK_SIZE];
        let reader = self.project().reader;
        let size = loop {
            match reader.read(&mut buffer) {
                Ok(size) => break size,
                // Per the `io::Read` contract an interrupted read carries no data and
                // should simply be retried instead of being reported as a failure.
                Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
                Err(err) => return Poll::Ready(Some(Err(err))),
            }
        };

        Poll::Ready(if size > 0 {
            buffer.truncate(size);
            Some(Ok(buffer.into()))
        } else {
            None
        })
    }
}
impl<R> From<SyncBody<R>> for Body
where
R: io::Read + Send + Sync + 'static,
{
fn from(value: SyncBody<R>) -> Body {
Body::wrap_stream(value)
}
}
#[cfg(test)]
mod test {
    use std::io::{Cursor, Read};

    use bytes::Bytes;
    use futures::stream;
    use reqwest::{Body, Error as HttpError};

    use super::super::session::test;
    use super::super::{ApiVersion, Error};
    use super::{SyncBody, SyncSession, SyncStream};

    /// Build a `SyncSession` around `test::new_simple_session(url)`.
    fn new_simple_sync_session(url: &str) -> SyncSession {
        let session = test::new_simple_session(url);
        SyncSession::new(session)
    }

    /// Build a `SyncSession` around a test session created with `test::fake_service_info()`.
    fn new_sync_session(url: &str) -> SyncSession {
        let session = test::new_session(url, test::fake_service_info());
        SyncSession::new(session)
    }

    #[test]
    fn test_get_api_versions_absent() {
        let session = new_simple_sync_session(test::URL);
        let versions = session.get_api_versions(test::FAKE).unwrap();
        assert!(versions.is_none());
    }

    #[test]
    fn test_get_api_versions_present() {
        let session = new_sync_session(test::URL);
        let versions = session.get_api_versions(test::FAKE).unwrap();
        let (min, max) = versions.unwrap();
        assert_eq!(min, test::MIN_VERSION);
        assert_eq!(max, test::MAX_VERSION);
    }

    #[test]
    fn test_get_endpoint() {
        let session = new_simple_sync_session(test::URL);
        let endpoint = session.get_endpoint(test::FAKE, &[""]).unwrap();
        assert_eq!(&endpoint.to_string(), test::URL);
    }

    #[test]
    fn test_get_endpoint_slice() {
        let session = new_simple_sync_session(test::URL);
        let endpoint = session.get_endpoint(test::FAKE, &["v2", "servers"]).unwrap();
        assert_eq!(&endpoint.to_string(), test::URL_WITH_SUFFIX);
    }

    #[test]
    fn test_get_endpoint_vec() {
        let session = new_simple_sync_session(test::URL);
        let path = vec!["v2".to_string(), "servers".to_string()];
        let endpoint = session.get_endpoint(test::FAKE, path).unwrap();
        assert_eq!(&endpoint.to_string(), test::URL_WITH_SUFFIX);
    }

    #[test]
    fn test_get_major_version_absent() {
        let session = new_simple_sync_session(test::URL);
        let major = session.get_major_version(test::FAKE).unwrap();
        assert!(major.is_none());
    }

    #[test]
    fn test_get_major_version_present() {
        let session = new_sync_session(test::URL);
        let major = session.get_major_version(test::FAKE).unwrap();
        assert_eq!(major, Some(test::MAJOR_VERSION));
    }

    #[test]
    fn test_pick_api_version_empty() {
        let session = new_sync_session(test::URL);
        let picked = session.pick_api_version(test::FAKE, None).unwrap();
        assert!(picked.is_none());
    }

    #[test]
    fn test_pick_api_version_empty_vec() {
        let session = new_sync_session(test::URL);
        let picked = session.pick_api_version(test::FAKE, Vec::new()).unwrap();
        assert!(picked.is_none());
    }

    #[test]
    fn test_pick_api_version() {
        let session = new_sync_session(test::URL);
        let candidates = vec![
            ApiVersion(2, 0),
            ApiVersion(2, 2),
            ApiVersion(2, 4),
            ApiVersion(2, 99),
        ];
        let picked = session.pick_api_version(test::FAKE, candidates).unwrap();
        assert_eq!(picked, Some(ApiVersion(2, 4)));
    }

    #[test]
    fn test_pick_api_version_option() {
        let session = new_sync_session(test::URL);
        let picked = session
            .pick_api_version(test::FAKE, Some(ApiVersion(2, 4)))
            .unwrap();
        assert_eq!(picked, Some(ApiVersion(2, 4)));
    }

    #[test]
    fn test_pick_api_version_impossible() {
        let session = new_sync_session(test::URL);
        let candidates = vec![ApiVersion(2, 0), ApiVersion(2, 99)];
        let picked = session.pick_api_version(test::FAKE, candidates).unwrap();
        assert!(picked.is_none());
    }

    #[test]
    fn test_stream_empty() {
        let session = new_sync_session(test::URL);
        let chunks = stream::empty::<Result<Bytes, HttpError>>();
        let mut reader = SyncStream::new(&session, chunks);
        let mut output = Vec::new();
        assert_eq!(0, reader.read_to_end(&mut output).unwrap());
    }

    #[test]
    fn test_stream_all() {
        let session = new_sync_session(test::URL);
        let chunks: Vec<Result<Bytes, Error>> = vec![
            Ok(Bytes::from(vec![1u8, 2, 3])),
            Ok(Bytes::from(vec![4u8])),
            Ok(Bytes::from(vec![5u8, 6])),
        ];
        let mut reader = SyncStream::new(&session, stream::iter(chunks));
        let mut output = Vec::new();
        assert_eq!(6, reader.read_to_end(&mut output).unwrap());
        assert_eq!(vec![1, 2, 3, 4, 5, 6], output);
    }

    #[test]
    fn test_stream_parts() {
        let session = new_sync_session(test::URL);
        let chunks: Vec<Result<Bytes, Error>> = vec![
            Ok(Bytes::from(vec![1u8, 2, 3])),
            Ok(Bytes::from(vec![4u8])),
            Ok(Bytes::from(vec![5u8, 6, 7, 8])),
        ];
        let mut reader = SyncStream::new(&session, stream::iter(chunks));
        // Reads never span chunk boundaries, so stale bytes stay in the tail of `window`.
        let mut window = [0; 3];
        assert_eq!(3, reader.read(&mut window).unwrap());
        assert_eq!([1, 2, 3], window);
        assert_eq!(1, reader.read(&mut window).unwrap());
        assert_eq!([4, 2, 3], window);
        assert_eq!(3, reader.read(&mut window).unwrap());
        assert_eq!([5, 6, 7], window);
        assert_eq!(1, reader.read(&mut window).unwrap());
        assert_eq!([8, 6, 7], window);
        assert_eq!(0, reader.read(&mut window).unwrap());
    }

    #[test]
    fn test_body() {
        let session = new_sync_session(test::URL);
        let payload = vec![42; 16_777_000];
        let body = SyncBody::new(Cursor::new(payload));
        let mut reader = SyncStream::new(&session, body);
        let mut output = Vec::new();
        assert_eq!(16_777_000, reader.read_to_end(&mut output).unwrap());
    }

    #[test]
    fn test_body_to_chunk() {
        let payload = vec![42; 16_777_000];
        let body = SyncBody::new(Cursor::new(payload));
        let _ = Body::from(body);
    }
}