Payload reached size limit.

1
closed
laik
laik
Posted 7 months ago

Payload reached size limit. #539

Your issue may already be reported! Please search on the Actix issue tracker before creating one.


use actix_web::body::MessageBody;
use actix_web::dev::{self, ServiceResponse};
use actix_web::{self, body};
use bytes::Bytes;
use futures::{self, StreamExt};

use lazy_static::lazy_static;

use actix_web::http::header::{HeaderMap, HeaderName};
use actix_web::{HttpRequest, HttpResponse};
use awc::{Client, ClientRequest};

use std::net::SocketAddr;
use std::time::Duration;

lazy_static! {
    /// `x-forwarded-for`: the header carrying the chain of client IPs.
    static ref HEADER_X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
    /// Hop-by-hop headers. They describe a single connection, so they are
    /// stripped from forwarded requests (with a few exceptions) and never
    /// copied onto proxied responses.
    static ref HOP_BY_HOP_HEADERS: Vec<HeaderName> = [
        "connection",
        "proxy-connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    ]
    .iter()
    .copied()
    .map(HeaderName::from_static)
    .collect();
    /// `te`: checked on its own because `TE: trailers` is allowed through.
    static ref HEADER_TE: HeaderName = HeaderName::from_static("te");
    /// `connection`: its value names further connection-specific headers.
    static ref HEADER_CONNECTION: HeaderName = HeaderName::from_static("connection");
}

/// Upstream request timeout used unless [`ReverseProxy::timeout`] overrides it.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(60);

/// Forwards incoming requests to a single upstream base URL.
pub struct ReverseProxy<'a> {
    /// Prefix the incoming path (and query) is appended to verbatim.
    forward_url: &'a str,
    /// Timeout applied to each upstream request.
    timeout: Duration,
}

/// Appends the IP part of `client_addr` to an `X-Forwarded-For` value.
///
/// A `", "` separator is inserted only when the value already lists earlier
/// hops; the port is never included.
fn add_client_ip(fwd_header_value: &mut String, client_addr: SocketAddr) {
    let separator = if fwd_header_value.is_empty() { "" } else { ", " };
    fwd_header_value.push_str(separator);
    fwd_header_value.push_str(&client_addr.ip().to_string());
}

/// Removes every header that the `Connection` header(s) in `headers` name as
/// connection-specific (RFC 7230 §6.1).
///
/// The `Connection` header itself is left in place unless it names itself;
/// callers strip it together with the other hop-by-hop headers.
///
/// All `Connection` fields are honoured, not just the first. Values that are
/// not visible ASCII are skipped instead of panicking, because they come from
/// the peer (client or upstream). Empty tokens, e.g. from a trailing comma,
/// are ignored.
fn remove_connection_headers(headers: &mut HeaderMap) {
    // Collect first: the names borrow `headers`, which is mutated below.
    let listed: Vec<String> = headers
        .get_all(&*HEADER_CONNECTION)
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .map(str::trim)
        .filter(|token| !token.is_empty())
        .map(String::from)
        .collect();

    for name in listed {
        headers.remove(name);
    }
}

/// Strips hop-by-hop headers from an outgoing request.
///
/// Each header in `HOP_BY_HOP_HEADERS` is removed unless one of these holds
/// for its (first) value:
/// - the value is empty, or
/// - the header is `TE` and the value is `trailers`.
fn remove_request_hop_by_hop_headers(headers: &mut HeaderMap) {
    for name in HOP_BY_HOP_HEADERS.iter() {
        let keep = match headers.get(name) {
            Some(value) => value == "" || (*name == *HEADER_TE && value == "trailers"),
            None => false,
        };
        if !keep {
            headers.remove(name);
        }
    }
}

/// Collects the entire body of `res` into memory.
///
/// # Panics
/// Panics with "error reading test response body" if the body stream yields
/// an error. The message suggests this helper is meant for tests.
pub async fn read_body<B>(res: ServiceResponse<B>) -> Bytes
where
    B: MessageBody,
{
    let collected = body::to_bytes(res.into_body()).await;
    collected
        .map_err(|err| -> Box<dyn std::error::Error> { err.into() })
        .expect("error reading test response body")
}

impl<'a> ReverseProxy<'a> {
    pub fn new(forward_url: &'a str) -> ReverseProxy<'a> {
        ReverseProxy {
            forward_url,
            timeout: DEFAULT_TIMEOUT,
        }
    }

    pub fn timeout(mut self, duration: Duration) -> ReverseProxy<'a> {
        self.timeout = duration;
        self
    }

    #[inline]
    fn x_forwarded_for_value(&self, req: &HttpRequest) -> String {
        let mut result = String::new();

        for (key, value) in req.headers() {
            if key == *HEADER_X_FORWARDED_FOR {
                result.push_str(value.to_str().unwrap());
                break;
            }
        }

        // adds client IP address
        // to x-forwarded-for header
        // if it's available
        if let Some(peer_addr) = req.peer_addr() {
            add_client_ip(&mut result, peer_addr);
        }

        result
    }

    fn forward_uri(&self, req: &HttpRequest) -> String {
        let forward_url: &str = self.forward_url;

        let forward_uri = match req.uri().query() {
            Some(query) => format!("{}{}?{}", forward_url, req.uri().path(), query),
            None => format!("{}{}", forward_url, req.uri().path()),
        };

        forward_uri
    }

    pub async fn forward(
        &self,
        req: &mut dev::ServiceRequest,
    ) -> Result<actix_web::HttpResponse, actix_web::Error> {
        let (req, payload) = req.parts_mut();

        let mut forward_req = ClientRequest::from(
            Client::new().request(req.method().clone(), &self.forward_uri(&req)),
        );

        // remove hop-by-hop headers
        remove_connection_headers(forward_req.headers_mut());
        remove_request_hop_by_hop_headers(forward_req.headers_mut());

        let mut body = actix_web::web::BytesMut::new();
        while let Some(item) = payload.next().await {
            body.extend_from_slice(&item?);
        }

        match forward_req
            .insert_header_if_none(("x-forwarded-for", self.x_forwarded_for_value(&req)))
            .insert_header_if_none((actix_web::http::header::USER_AGENT, ""))
            .timeout(self.timeout)
            .send_body(body)
            .await
            .map(|mut resp| async move {
                let mut back_rsp = HttpResponse::build(resp.status());
                // copy headers
                for (key, value) in resp.headers() {
                    if !HOP_BY_HOP_HEADERS.contains(key) {
                        back_rsp.insert_header((key.clone(), value.clone()));
                    }
                }

                let mut back_rsp = match resp.body().await {
                    Ok(body) => back_rsp.body(body),
                    Err(e) => {
                        log::error!("resp payload error {:?}", e.to_string());
                        back_rsp.body(e.to_string())
                    }
                };

                remove_connection_headers(back_rsp.headers_mut());

                back_rsp
            }) {
            Ok(resp) => Ok(resp.await),
            Err(e) => Err(actix_web::error::ErrorInternalServerError(e)),
        }
    }
}

// main.rs

const REVERSE_PROXY_BIND_ADDRESS: &'static str = "0.0.0.0:30000";

async fn proxy(req: dev::ServiceRequest) -> Result<dev::ServiceResponse, actix_web::Error> {
    let mut req = req;
    match ReverseProxy::new("http://0.0.0.0:3000")
        .timeout(Duration::from_secs(5))
        .forward(&mut req)
        .await
    {
        Ok(resp) => Ok(req.into_response(resp)),
        Err(e) => Err(e),
    }
}

/// Starts the reverse proxy on `REVERSE_PROXY_BIND_ADDRESS`.
///
/// Every path is routed to the `proxy` service. Returns the I/O error if
/// binding the address fails, instead of panicking.
#[tokio::main]
async fn main() -> Result<(), std::io::Error> {
    // A missing .env file is deliberately not an error.
    dotenv::dotenv().ok();
    env_logger::init();

    HttpServer::new(|| {
        App::new()
            // NOTE(review): `PayloadConfig` configures actix's body extractors.
            // `proxy` reads the raw `ServiceRequest` payload itself, so this
            // likely has no effect here. It never governed the awc *response*
            // limit, which caused "Payload reached size limit." Confirm before
            // relying on it.
            .app_data(web::PayloadConfig::new(1024 * 1024 * 100)) // 100Mb
            .service(web::service("/{tail}*").finish(proxy))
    })
    .bind(REVERSE_PROXY_BIND_ADDRESS)?
    .run()
    .await
}

  • Rust Version (i.e., output of `rustc -V`): rustc 1.64.0-nightly (f8588549c 2022-07-18)
  • Actix Version: actix-web = "4"
laik
laik
Created 7 months ago

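Solved: awc buffers the upstream response body with a small default size limit. Raising it explicitly with `resp.body().limit(self.body_limit).await` fixes the error.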