diff --git a/CHANGELOG.md b/CHANGELOG.md index 89eb9011..ed4f6e85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# 2.2.1 +* Integrated RandomSFX algo (rx/sfx) +* Performance improvements for RandomX variants +* Fixed some cc connection problems # 2.2.0 * Integrated RandomxARQ algo (rx/arq) * Dashboard: diff --git a/CMakeLists.txt b/CMakeLists.txt index 7d3f9059..7e69fc43 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -261,6 +261,8 @@ if (WITH_CC_SERVER OR WITH_CC_CLIENT) src/cc/ClientStatus.cpp src/cc/GPUInfo.cpp) + add_definitions("/DCPPHTTPLIB_USE_POLL") + if (WITH_ZLIB) set(ZLIB_ROOT ${XMRIG_DEPS}) find_package(ZLIB) diff --git a/README.md b/README.md index 6c1ad015..3916de3d 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ Full Windows/Linux compatible, and you can mix Linux and Windows miner on one XM ## Additional features of XMRigCC (on top of XMRig) Check the [Coin Configuration](https://github.com/Bendr0id/xmrigCC/wiki/Coin-configurations) guide +* **Support of RandomxSFX variant (algo: "rx/sfx")** * **Support of RandomxARQ variant (algo: "rx/arq")** * **Support of UPX2 variant (algo: "cn-extremelite/upx2")** * **Support of CN-Conceal variant (algo: "cn/conceal")** @@ -133,7 +134,7 @@ xmrigDaemon -o pool.hashvault.pro:5555 -u YOUR_WALLET -p x -k --cc-url=IP_OF_CC_ cn-pico cn-extremelite argon2/chukwa, argon2/wrkz - rx/wow, rx/loki, rx/arq + rx/wow, rx/loki, rx/arq, rx/sfx --coin=COIN specify coin instead of algorithm -o, --url=URL URL of mining server -O, --userpass=U:P username:password pair for mining server diff --git a/src/3rdparty/cpp-httplib/httplib.h b/src/3rdparty/cpp-httplib/httplib.h index afc5aa9d..e831fe0d 100644 --- a/src/3rdparty/cpp-httplib/httplib.h +++ b/src/3rdparty/cpp-httplib/httplib.h @@ -11,6 +11,7 @@ /* * Configuration */ + #ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND #define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 #endif @@ -51,6 +52,10 @@ #define CPPHTTPLIB_THREAD_POOL_COUNT 8 #endif +/* + * Headers + */ + #ifdef _WIN32 #ifndef _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS @@ -62,9 +67,9 @@ #if defined(_MSC_VER) #ifdef _WIN64 -typedef __int64 ssize_t; +using ssize_t = __int64; #else -typedef int ssize_t; +using ssize_t = int; #endif #if _MSC_VER < 1900 @@ -100,7 +105,7 @@ typedef int ssize_t; #define strcasecmp _stricmp #endif // strcasecmp -typedef SOCKET socket_t; +using socket_t = SOCKET; #ifdef CPPHTTPLIB_USE_POLL #define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) #endif @@ -114,18 +119,19 @@ typedef SOCKET socket_t; #ifdef CPPHTTPLIB_USE_POLL #include #endif +#include #include -#include #include #include #include -typedef int socket_t; +using socket_t = int; #define INVALID_SOCKET (-1) #endif //_WIN32 -#include +#include #include +#include #include #include #include @@ -162,266 +168,303 @@ inline const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *asn1) { #include #endif +/* + * Declaration + */ namespace httplib { -namespace detail { + namespace detail { -struct ci { - bool operator()(const std::string &s1, const std::string &s2) const { - return std::lexicographical_compare( - s1.begin(), s1.end(), s2.begin(), s2.end(), - [](char c1, char c2) { return ::tolower(c1) < ::tolower(c2); }); - } -}; + struct ci { + bool operator()(const std::string &s1, const std::string &s2) const { + return std::lexicographical_compare( + s1.begin(), s1.end(), s2.begin(), s2.end(), + [](char c1, char c2) { return ::tolower(c1) < ::tolower(c2); }); + } + }; -} // namespace detail + } // namespace detail -enum class HttpVersion { v1_0 = 0, v1_1 }; + enum class HttpVersion { v1_0 = 0, v1_1 }; -typedef std::multimap Headers; + using Headers = std::multimap; -typedef std::multimap Params; -typedef std::smatch Match; + using Params = std::multimap; + using Match = std::smatch; -typedef std::function DataSink; + using DataSink = std::function; -typedef std::function Done; + using Done = std::function; -typedef std::function - ContentProvider; + using ContentProvider = + std::function; -typedef std::function - ContentReceiver; + using ContentProviderWithCloser = + std::function; -typedef std::function Progress; + using Progress = std::function; -struct Response; -typedef std::function ResponseHandler; + struct Response; + using ResponseHandler = std::function; -struct MultipartFile { - std::string filename; - std::string content_type; - size_t offset = 0; - size_t length = 0; -}; -typedef std::multimap MultipartFiles; + struct MultipartFile { + std::string filename; + std::string content_type; + std::string content; + }; + using MultipartFiles = std::multimap; -struct MultipartFormData { - std::string name; - std::string content; - std::string filename; - std::string content_type; -}; -typedef std::vector MultipartFormDataItems; + struct MultipartFormData { + std::string name; + std::string content; + std::string filename; + std::string content_type; + }; + using MultipartFormDataItems = std::vector; -typedef std::pair Range; -typedef std::vector Ranges; + using ContentReceiver = + std::function; -struct Request { - std::string method; - std::string path; - Headers headers; - std::string body; + using MultipartContentHeader = + std::function; - // for server - std::string remoteAddr; - std::string version; - std::string target; - Params params; - MultipartFiles files; - Ranges ranges; - Match matches; + using MultipartContentReceiver = + std::function; - // for client - size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT; - ResponseHandler response_handler; - ContentReceiver content_receiver; - Progress progress; + class ContentReader { + public: + using Reader = std::function; + using MultipartReader = std::function; + + ContentReader(Reader reader, MultipartReader muitlpart_reader) + : reader_(reader), muitlpart_reader_(muitlpart_reader) {} + + bool operator()(MultipartContentHeader header, MultipartContentReceiver receiver) const { + return muitlpart_reader_(header, receiver); + } + + bool operator()(ContentReceiver receiver) const { + return reader_(receiver); + } + + Reader reader_; + MultipartReader muitlpart_reader_; + }; + + using Range = std::pair; + using Ranges = std::vector; + + struct Request { + std::string method; + std::string path; + Headers headers; + std::string body; + + // for server + std::string version; + std::string target; + Params params; + MultipartFiles files; + Ranges ranges; + Match matches; + + // for client + size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT; + ResponseHandler response_handler; + ContentReceiver content_receiver; + Progress progress; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - const SSL *ssl; + const SSL *ssl; #endif - bool has_header(const char *key) const; - std::string get_header_value(const char *key, size_t id = 0) const; - size_t get_header_value_count(const char *key) const; - void set_header(const char *key, const char *val); - void set_header(const char *key, const std::string &val); + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); - bool has_param(const char *key) const; - std::string get_param_value(const char *key, size_t id = 0) const; - size_t get_param_value_count(const char *key) const; + bool has_param(const char *key) const; + std::string get_param_value(const char *key, size_t id = 0) const; + size_t get_param_value_count(const char *key) const; - bool has_file(const char *key) const; - MultipartFile get_file_value(const char *key) const; -}; + bool is_multipart_form_data() const; -struct Response { - std::string version; - int status; - Headers headers; - std::string body; + bool has_file(const char *key) const; + MultipartFile get_file_value(const char *key) const; - bool has_header(const char *key) const; - std::string get_header_value(const char *key, size_t id = 0) const; - size_t get_header_value_count(const char *key) const; - void set_header(const char *key, const char *val); - void set_header(const char *key, const std::string &val); + // private members... + size_t content_length; + ContentProvider content_provider; + }; - void set_redirect(const char *uri); - void set_content(const char *s, size_t n, const char *content_type); - void set_content(const std::string &s, const char *content_type); + struct Response { + std::string version; + int status; + Headers headers; + std::string body; - void set_content_provider( + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); + + void set_redirect(const char *url); + void set_content(const char *s, size_t n, const char *content_type); + void set_content(const std::string &s, const char *content_type); + + void set_content_provider( size_t length, std::function provider, std::function resource_releaser = [] {}); - void set_chunked_content_provider( + void set_chunked_content_provider( std::function provider, std::function resource_releaser = [] {}); - Response() : status(-1), content_provider_resource_length(0) {} + Response() : status(-1), content_length(0) {} - ~Response() { - if (content_provider_resource_releaser) { - content_provider_resource_releaser(); - } - } - - size_t content_provider_resource_length; - ContentProvider content_provider; - std::function content_provider_resource_releaser; -}; - -class Stream { -public: - virtual ~Stream() {} - virtual int read(char *ptr, size_t size) = 0; - virtual int write(const char *ptr, size_t size1) = 0; - virtual int write(const char *ptr) = 0; - virtual int write(const std::string &s) = 0; - virtual std::string get_remote_addr() const = 0; - - template - int write_format(const char *fmt, const Args &... args); -}; - -class SocketStream : public Stream { -public: - SocketStream(socket_t sock); - virtual ~SocketStream(); - - virtual int read(char *ptr, size_t size); - virtual int write(const char *ptr, size_t size); - virtual int write(const char *ptr); - virtual int write(const std::string &s); - virtual std::string get_remote_addr() const; - -private: - socket_t sock_; -}; - -class BufferStream : public Stream { -public: - BufferStream() {} - virtual ~BufferStream() {} - - virtual int read(char *ptr, size_t size); - virtual int write(const char *ptr, size_t size); - virtual int write(const char *ptr); - virtual int write(const std::string &s); - virtual std::string get_remote_addr() const; - - const std::string &get_buffer() const; - -private: - std::string buffer; -}; - -class TaskQueue { -public: - TaskQueue() {} - virtual ~TaskQueue() {} - virtual void enqueue(std::function fn) = 0; - virtual void shutdown() = 0; -}; - -#if CPPHTTPLIB_THREAD_POOL_COUNT > 0 -class ThreadPool : public TaskQueue { -public: - ThreadPool(size_t n) : shutdown_(false) { - while (n) { - auto t = std::make_shared(worker(*this)); - threads_.push_back(t); - n--; - } - } - - ThreadPool(const ThreadPool &) = delete; - virtual ~ThreadPool() {} - - virtual void enqueue(std::function fn) override { - std::unique_lock lock(mutex_); - jobs_.push_back(fn); - cond_.notify_one(); - } - - virtual void shutdown() override { - // Stop all worker threads... - { - std::unique_lock lock(mutex_); - shutdown_ = true; - } - - cond_.notify_all(); - - // Join... - for (auto t : threads_) { - t->join(); - } - } - -private: - struct worker { - worker(ThreadPool &pool) : pool_(pool) {} - - void operator()() { - for (;;) { - std::function fn; - { - std::unique_lock lock(pool_.mutex_); - - pool_.cond_.wait( - lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); - - if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } - - fn = pool_.jobs_.front(); - pool_.jobs_.pop_front(); - } - - assert(true == static_cast(fn)); - fn(); + ~Response() { + if (content_provider_resource_releaser) { + content_provider_resource_releaser(); } } - ThreadPool &pool_; + // private members... + size_t content_length; + ContentProviderWithCloser content_provider; + std::function content_provider_resource_releaser; }; - friend struct worker; - std::vector> threads_; - std::list> jobs_; + class Stream { + public: + virtual ~Stream() = default; + virtual int read(char *ptr, size_t size) = 0; + virtual int write(const char *ptr, size_t size1) = 0; + virtual int write(const char *ptr) = 0; + virtual int write(const std::string &s) = 0; + virtual std::string get_remote_addr() const = 0; - bool shutdown_; + template + int write_format(const char *fmt, const Args &... args); + }; - std::condition_variable cond_; - std::mutex mutex_; -}; -#else -class Threads : public TaskQueue { + class SocketStream : public Stream { + public: + SocketStream(socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec); + ~SocketStream() override; + + int read(char *ptr, size_t size) override; + int write(const char *ptr, size_t size) override; + int write(const char *ptr) override; + int write(const std::string &s) override; + std::string get_remote_addr() const override; + + private: + socket_t sock_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + }; + + class BufferStream : public Stream { + public: + BufferStream() = default; + ~BufferStream() override = default; + + int read(char *ptr, size_t size) override; + int write(const char *ptr, size_t size) override; + int write(const char *ptr) override; + int write(const std::string &s) override; + std::string get_remote_addr() const override; + + const std::string &get_buffer() const; + + private: + std::string buffer; + }; + + class TaskQueue { + public: + TaskQueue() = default; + virtual ~TaskQueue() = default; + virtual void enqueue(std::function fn) = 0; + virtual void shutdown() = 0; + }; + +#if CPPHTTPLIB_THREAD_POOL_COUNT > 0 + class ThreadPool : public TaskQueue { + public: + explicit ThreadPool(size_t n) : shutdown_(false) { + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } + } + + ThreadPool(const ThreadPool &) = delete; + ~ThreadPool() override = default; + + void enqueue(std::function fn) override { + std::unique_lock lock(mutex_); + jobs_.push_back(fn); + cond_.notify_one(); + } + + void shutdown() override { + // Stop all worker threads... + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + // Join... + for (auto &t : threads_) { + t.join(); + } + } + + private: + struct worker { + explicit worker(ThreadPool &pool) : pool_(pool) {} + + void operator()() { + for (;;) { + std::function fn; + { + std::unique_lock lock(pool_.mutex_); + + pool_.cond_.wait( + lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + + if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } + + assert(true == static_cast(fn)); + fn(); + } + } + + ThreadPool &pool_; + }; + friend struct worker; + + std::vector threads_; + std::list> jobs_; + + bool shutdown_; + + std::condition_variable cond_; + std::mutex mutex_; + }; +#elif CPPHTTPLIB_THREAD_POOL_COUNT == 0 + class Threads : public TaskQueue { public: Threads() : running_threads_(0) {} virtual ~Threads() {} @@ -454,596 +497,694 @@ private: std::mutex running_threads_mutex_; int running_threads_; }; +#else +class NoThread : public TaskQueue { +public: + NoThread() {} + virtual ~NoThread() {} + + virtual void enqueue(std::function fn) override { fn(); } + + virtual void shutdown() override {} +}; #endif -class Server { -public: - typedef std::function Handler; - typedef std::function Logger; + class Server { + public: + using Handler = std::function; + using HandlerWithContentReader = std::function; + using Logger = std::function; - Server(); + Server(); - virtual ~Server(); + virtual ~Server(); - virtual bool is_valid() const; + virtual bool is_valid() const; - Server &Get(const char *pattern, Handler handler); - Server &Post(const char *pattern, Handler handler); + Server &Get(const char *pattern, Handler handler); + Server &Post(const char *pattern, Handler handler); + Server &Post(const char *pattern, HandlerWithContentReader handler); + Server &Put(const char *pattern, Handler handler); + Server &Put(const char *pattern, HandlerWithContentReader handler); + Server &Patch(const char *pattern, Handler handler); + Server &Patch(const char *pattern, HandlerWithContentReader handler); + Server &Delete(const char *pattern, Handler handler); + Server &Options(const char *pattern, Handler handler); - Server &Put(const char *pattern, Handler handler); - Server &Patch(const char *pattern, Handler handler); - Server &Delete(const char *pattern, Handler handler); - Server &Options(const char *pattern, Handler handler); + bool set_base_dir(const char *dir, const char *mount_point = nullptr); + void set_file_request_handler(Handler handler); - bool set_base_dir(const char *path); - void set_file_request_handler(Handler handler); + void set_error_handler(Handler handler); + void set_logger(Logger logger); - void set_error_handler(Handler handler); - void set_logger(Logger logger); + void set_keep_alive_max_count(size_t count); + void set_read_timeout(time_t sec, time_t usec); + void set_payload_max_length(size_t length); - void set_keep_alive_max_count(size_t count); - void set_payload_max_length(size_t length); + bool bind_to_port(const char *host, int port, int socket_flags = 0); + int bind_to_any_port(const char *host, int socket_flags = 0); + bool listen_after_bind(); - bool bind_to_port(const char *host, int port, int socket_flags = 0); - int bind_to_any_port(const char *host, int socket_flags = 0); - bool listen_after_bind(); + bool listen(const char *host, int port, int socket_flags = 0); - bool listen(const char *host, int port, int socket_flags = 0); + bool is_running() const; + void stop(); - bool is_running() const; - void stop(); + std::function new_task_queue; - std::function new_task_queue; + protected: + bool process_request(Stream &strm, bool last_connection, + bool &connection_close, + const std::function &setup_request); -protected: - bool process_request(Stream &strm, bool last_connection, - bool &connection_close, - std::function setup_request); + size_t keep_alive_max_count_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + size_t payload_max_length_; - size_t keep_alive_max_count_; - size_t payload_max_length_; + private: + using Handlers = std::vector>; + using HandersForContentReader = + std::vector>; -private: - typedef std::vector> Handlers; + socket_t create_server_socket(const char *host, int port, + int socket_flags) const; + int bind_internal(const char *host, int port, int socket_flags); + bool listen_internal(); - socket_t create_server_socket(const char *host, int port, - int socket_flags) const; - int bind_internal(const char *host, int port, int socket_flags); - bool listen_internal(); + bool routing(Request &req, Response &res, Stream &strm, bool last_connection); + bool handle_file_request(Request &req, Response &res); + bool dispatch_request(Request &req, Response &res, Handlers &handlers); + bool dispatch_request_for_content_reader(Request &req, Response &res, + ContentReader content_reader, + HandersForContentReader &handlers); - bool routing(Request &req, Response &res); - bool handle_file_request(Request &req, Response &res); - bool dispatch_request(Request &req, Response &res, Handlers &handlers); - - bool parse_request_line(const char *s, Request &req); - bool write_response(Stream &strm, bool last_connection, const Request &req, + bool parse_request_line(const char *s, Request &req); + bool write_response(Stream &strm, bool last_connection, const Request &req, + Response &res); + bool write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type); + bool read_content(Stream &strm, bool last_connection, Request &req, Response &res); - bool write_content_with_provider(Stream &strm, const Request &req, - Response &res, const std::string &boundary, - const std::string &content_type); + bool read_content_with_content_receiver(Stream &strm, bool last_connection, + Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + MultipartContentReceiver multipart_receiver); + bool read_content_core(Stream &strm, bool last_connection, + Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader mulitpart_header, + MultipartContentReceiver multipart_receiver); - virtual bool process_and_close_socket(socket_t sock); + virtual bool process_and_close_socket(socket_t sock); - std::atomic is_running_; - std::atomic svr_sock_; - std::string base_dir_; - Handler file_request_handler_; - Handlers get_handlers_; - Handlers post_handlers_; - Handlers put_handlers_; - Handlers patch_handlers_; - Handlers delete_handlers_; - Handlers options_handlers_; - Handler error_handler_; - Logger logger_; -}; + std::atomic is_running_; + std::atomic svr_sock_; + std::vector> base_dirs_; + Handler file_request_handler_; + Handlers get_handlers_; + Handlers post_handlers_; + HandersForContentReader post_handlers_for_content_reader_; + Handlers put_handlers_; + HandersForContentReader put_handlers_for_content_reader_; + Handlers patch_handlers_; + HandersForContentReader patch_handlers_for_content_reader_; + Handlers delete_handlers_; + Handlers options_handlers_; + Handler error_handler_; + Logger logger_; + }; -class Client { -public: - Client(const char *host, int port = 80, time_t timeout_sec = 300); + class Client { + public: + explicit Client(const char *host, int port = 80, time_t timeout_sec = 300); - virtual ~Client(); + virtual ~Client(); - virtual bool is_valid() const; + virtual bool is_valid() const; - std::shared_ptr Get(const char *path); + std::shared_ptr Get(const char *path); - std::shared_ptr Get(const char *path, const Headers &headers); + std::shared_ptr Get(const char *path, const Headers &headers); - std::shared_ptr Get(const char *path, Progress progress); + std::shared_ptr Get(const char *path, Progress progress); - std::shared_ptr Get(const char *path, const Headers &headers, - Progress progress); + std::shared_ptr Get(const char *path, const Headers &headers, + Progress progress); - std::shared_ptr Get(const char *path, - ContentReceiver content_receiver); + std::shared_ptr Get(const char *path, + ContentReceiver content_receiver); - std::shared_ptr Get(const char *path, const Headers &headers, - ContentReceiver content_receiver); + std::shared_ptr Get(const char *path, const Headers &headers, + ContentReceiver content_receiver); - std::shared_ptr - Get(const char *path, ContentReceiver content_receiver, Progress progress); + std::shared_ptr + Get(const char *path, ContentReceiver content_receiver, Progress progress); - std::shared_ptr Get(const char *path, const Headers &headers, - ContentReceiver content_receiver, - Progress progress); + std::shared_ptr Get(const char *path, const Headers &headers, + ContentReceiver content_receiver, + Progress progress); - std::shared_ptr Get(const char *path, const Headers &headers, - ResponseHandler response_handler, - ContentReceiver content_receiver); + std::shared_ptr Get(const char *path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); - std::shared_ptr Get(const char *path, const Headers &headers, - ResponseHandler response_handler, - ContentReceiver content_receiver, - Progress progress); + std::shared_ptr Get(const char *path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress); - std::shared_ptr Head(const char *path); + std::shared_ptr Head(const char *path); - std::shared_ptr Head(const char *path, const Headers &headers); + std::shared_ptr Head(const char *path, const Headers &headers); - std::shared_ptr Post(const char *path, const std::string &body, - const char *content_type); + std::shared_ptr Post(const char *path, const std::string &body, + const char *content_type, + bool compress = false); - std::shared_ptr Post(const char *path, const Headers &headers, - const std::string &body, - const char *content_type); - - std::shared_ptr Post(const char *path, const Params ¶ms); - - std::shared_ptr Post(const char *path, const Headers &headers, - const Params ¶ms); - - std::shared_ptr Post(const char *path, - const MultipartFormDataItems &items); - - std::shared_ptr Post(const char *path, const Headers &headers, - const MultipartFormDataItems &items); - - std::shared_ptr Put(const char *path, const std::string &body, - const char *content_type); - - std::shared_ptr Put(const char *path, const Headers &headers, - const std::string &body, - const char *content_type); - - std::shared_ptr Patch(const char *path, const std::string &body, - const char *content_type); - - std::shared_ptr Patch(const char *path, const Headers &headers, - const std::string &body, - const char *content_type); - - std::shared_ptr Delete(const char *path); - - std::shared_ptr Delete(const char *path, const std::string &body, - const char *content_type); - - std::shared_ptr Delete(const char *path, const Headers &headers); - - std::shared_ptr Delete(const char *path, const Headers &headers, + std::shared_ptr Post(const char *path, const Headers &headers, const std::string &body, - const char *content_type); + const char *content_type, + bool compress = false); - std::shared_ptr Options(const char *path); + std::shared_ptr Post(const char *path, size_t content_length, + ContentProvider content_provider, + const char *content_type, + bool compress = false); - std::shared_ptr Options(const char *path, const Headers &headers); + std::shared_ptr Post(const char *path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const char *content_type, + bool compress = false); - bool send(const Request &req, Response &res); + std::shared_ptr Post(const char *path, const Params ¶ms, + bool compress = false); - bool send(const std::vector &requests, - std::vector &responses); + std::shared_ptr Post(const char *path, const Headers &headers, + const Params ¶ms, bool compress = false); - void set_keep_alive_max_count(size_t count); + std::shared_ptr Post(const char *path, + const MultipartFormDataItems &items, + bool compress = false); - void follow_location(bool on); + std::shared_ptr Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items, + bool compress = false); -protected: - bool process_request(Stream &strm, const Request &req, Response &res, - bool last_connection, bool &connection_close); + std::shared_ptr Put(const char *path, const std::string &body, + const char *content_type, + bool compress = false); - const std::string host_; - const int port_; - time_t timeout_sec_; - const std::string host_and_port_; - size_t keep_alive_max_count_; - size_t follow_location_; + std::shared_ptr Put(const char *path, const Headers &headers, + const std::string &body, + const char *content_type, + bool compress = false); -private: - socket_t create_client_socket() const; - bool read_response_line(Stream &strm, Response &res); - void write_request(Stream &strm, const Request &req, bool last_connection); - bool redirect(const Request &req, Response &res); + std::shared_ptr Put(const char *path, size_t content_length, + ContentProvider content_provider, + const char *content_type, + bool compress = false); - virtual bool process_and_close_socket( + std::shared_ptr Put(const char *path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const char *content_type, + bool compress = false); + + std::shared_ptr Patch(const char *path, const std::string &body, + const char *content_type, + bool compress = false); + + std::shared_ptr Patch(const char *path, const Headers &headers, + const std::string &body, + const char *content_type, + bool compress = false); + + std::shared_ptr Patch(const char *path, size_t content_length, + ContentProvider content_provider, + const char *content_type, + bool compress = false); + + std::shared_ptr Patch(const char *path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const char *content_type, + bool compress = false); + + std::shared_ptr Delete(const char *path); + + std::shared_ptr Delete(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr Delete(const char *path, const Headers &headers); + + std::shared_ptr Delete(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr Options(const char *path); + + std::shared_ptr Options(const char *path, const Headers &headers); + + bool send(const Request &req, Response &res); + + bool send(const std::vector &requests, + std::vector &responses); + + void set_keep_alive_max_count(size_t count); + void set_read_timeout(time_t sec, time_t usec); + + void follow_location(bool on); + + protected: + bool process_request(Stream &strm, const Request &req, Response &res, + bool last_connection, bool &connection_close); + + const std::string host_; + const int port_; + time_t timeout_sec_; + const std::string host_and_port_; + size_t keep_alive_max_count_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + size_t follow_location_; + + private: + socket_t create_client_socket() const; + bool read_response_line(Stream &strm, Response &res); + void write_request(Stream &strm, const Request &req, bool last_connection); + bool redirect(const Request &req, Response &res); + + std::shared_ptr + send_with_content_provider(const char *method, const char *path, + const Headers &headers, const std::string &body, + size_t content_length, + ContentProvider content_provider, + const char *content_type, bool compress); + + virtual bool process_and_close_socket( socket_t sock, size_t request_count, std::function - callback); + callback); - virtual bool is_ssl() const; -}; + virtual bool is_ssl() const; + }; -inline void Get(std::vector &requests, const char *path, - const Headers &headers) { - Request req; - req.method = "GET"; - req.path = path; - req.headers = headers; - requests.emplace_back(std::move(req)); -} + inline void Get(std::vector &requests, const char *path, + const Headers &headers) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + requests.emplace_back(std::move(req)); + } -inline void Get(std::vector &requests, const char *path) { - Get(requests, path, Headers()); -} + inline void Get(std::vector &requests, const char *path) { + Get(requests, path, Headers()); + } -inline void Post(std::vector &requests, const char *path, - const Headers &headers, const std::string &body, - const char *content_type) { - Request req; - req.method = "POST"; - req.path = path; - req.headers = headers; - req.headers.emplace("Content-Type", content_type); - req.body = body; - requests.emplace_back(std::move(req)); -} + inline void Post(std::vector &requests, const char *path, + const Headers &headers, const std::string &body, + const char *content_type) { + Request req; + req.method = "POST"; + req.path = path; + req.headers = headers; + req.headers.emplace("Content-Type", content_type); + req.body = body; + requests.emplace_back(std::move(req)); + } -inline void Post(std::vector &requests, const char *path, - const std::string &body, const char *content_type) { - Post(requests, path, Headers(), body, content_type); -} + inline void Post(std::vector &requests, const char *path, + const std::string &body, const char *content_type) { + Post(requests, path, Headers(), body, content_type); + } #ifdef CPPHTTPLIB_OPENSSL_SUPPORT -class SSLSocketStream : public Stream { -public: - SSLSocketStream(socket_t sock, SSL *ssl); - virtual ~SSLSocketStream(); + class SSLSocketStream : public Stream { + public: + SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, + time_t read_timeout_usec); + virtual ~SSLSocketStream(); - virtual int read(char *ptr, size_t size); - virtual int write(const char *ptr, size_t size); - virtual int write(const char *ptr); - virtual int write(const std::string &s); - virtual std::string get_remote_addr() const; + virtual int read(char *ptr, size_t size); + virtual int write(const char *ptr, size_t size); + virtual int write(const char *ptr); + virtual int write(const std::string &s); + virtual std::string get_remote_addr() const; -private: - socket_t sock_; - SSL *ssl_; -}; + private: + socket_t sock_; + SSL *ssl_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + }; -class SSLServer : public Server { -public: - SSLServer(const char *cert_path, const char *private_key_path, - const char *client_ca_cert_file_path = nullptr, - const char *client_ca_cert_dir_path = nullptr); + class SSLServer : public Server { + public: + SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path = nullptr, + const char *client_ca_cert_dir_path = nullptr); - virtual ~SSLServer(); + virtual ~SSLServer(); - virtual bool is_valid() const; + virtual bool is_valid() const; -private: - virtual bool process_and_close_socket(socket_t sock); + private: + virtual bool process_and_close_socket(socket_t sock); - SSL_CTX *ctx_; - std::mutex ctx_mutex_; -}; + SSL_CTX *ctx_; + std::mutex ctx_mutex_; + }; -class SSLClient : public Client { -public: - SSLClient(const char *host, int port = 443, time_t timeout_sec = 300, - const char *client_cert_path = nullptr, - const char *client_key_path = nullptr); + class SSLClient : public Client { + public: + SSLClient(const char *host, int port = 443, time_t timeout_sec = 300, + const char *client_cert_path = nullptr, + const char *client_key_path = nullptr); - virtual ~SSLClient(); + virtual ~SSLClient(); - virtual bool is_valid() const; + virtual bool is_valid() const; - void set_ca_cert_path(const char *ca_ceert_file_path, - const char *ca_cert_dir_path = nullptr); - void enable_server_certificate_verification(bool enabled); + void set_ca_cert_path(const char *ca_ceert_file_path, + const char *ca_cert_dir_path = nullptr); + void enable_server_certificate_verification(bool enabled); - long get_openssl_verify_result() const; + long get_openssl_verify_result() const; - SSL_CTX* ssl_context() const noexcept; + SSL_CTX *ssl_context() const noexcept; -private: - virtual bool process_and_close_socket( + private: + virtual bool process_and_close_socket( socket_t sock, size_t request_count, std::function - callback); - virtual bool is_ssl() const; + callback); + virtual bool is_ssl() const; - bool verify_host(X509 *server_cert) const; - bool verify_host_with_subject_alt_name(X509 *server_cert) const; - bool verify_host_with_common_name(X509 *server_cert) const; - bool check_host_name(const char *pattern, size_t pattern_len) const; + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; + bool check_host_name(const char *pattern, size_t pattern_len) const; - SSL_CTX *ctx_; - std::mutex ctx_mutex_; - std::vector host_components_; - std::string ca_cert_file_path_; - std::string ca_cert_dir_path_; - bool server_certificate_verification_ = false; - long verify_result_ = 0; -}; + SSL_CTX *ctx_; + std::mutex ctx_mutex_; + std::vector host_components_; + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + bool server_certificate_verification_ = false; + long verify_result_ = 0; + }; #endif /* * Implementation */ -namespace detail { -inline bool is_hex(char c, int &v) { - if (0x20 <= c && isdigit(c)) { - v = c - '0'; - return true; - } else if ('A' <= c && c <= 'F') { - v = c - 'A' + 10; - return true; - } else if ('a' <= c && c <= 'f') { - v = c - 'a' + 10; - return true; - } - return false; -} + namespace detail { -inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, - int &val) { - if (i >= s.size()) { return false; } - - val = 0; - for (; cnt; i++, cnt--) { - if (!s[i]) { return false; } - int v = 0; - if (is_hex(s[i], v)) { - val = val * 16 + v; - } else { + inline bool is_hex(char c, int &v) { + if (0x20 <= c && isdigit(c)) { + v = c - '0'; + return true; + } else if ('A' <= c && c <= 'F') { + v = c - 'A' + 10; + return true; + } else if ('a' <= c && c <= 'f') { + v = c - 'a' + 10; + return true; + } return false; } - } - return true; -} -inline std::string from_i_to_hex(size_t n) { - const char *charset = "0123456789abcdef"; - std::string ret; - do { - ret = charset[n & 15] + ret; - n >>= 4; - } while (n > 0); - return ret; -} + inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, + int &val) { + if (i >= s.size()) { return false; } -inline size_t to_utf8(int code, char *buff) { - if (code < 0x0080) { - buff[0] = (code & 0x7F); - return 1; - } else if (code < 0x0800) { - buff[0] = (0xC0 | ((code >> 6) & 0x1F)); - buff[1] = (0x80 | (code & 0x3F)); - return 2; - } else if (code < 0xD800) { - buff[0] = (0xE0 | ((code >> 12) & 0xF)); - buff[1] = (0x80 | ((code >> 6) & 0x3F)); - buff[2] = (0x80 | (code & 0x3F)); - return 3; - } else if (code < 0xE000) { // D800 - DFFF is invalid... - return 0; - } else if (code < 0x10000) { - buff[0] = (0xE0 | ((code >> 12) & 0xF)); - buff[1] = (0x80 | ((code >> 6) & 0x3F)); - buff[2] = (0x80 | (code & 0x3F)); - return 3; - } else if (code < 0x110000) { - buff[0] = (0xF0 | ((code >> 18) & 0x7)); - buff[1] = (0x80 | ((code >> 12) & 0x3F)); - buff[2] = (0x80 | ((code >> 6) & 0x3F)); - buff[3] = (0x80 | (code & 0x3F)); - return 4; - } + val = 0; + for (; cnt; i++, cnt--) { + if (!s[i]) { return false; } + int v = 0; + if (is_hex(s[i], v)) { + val = val * 16 + v; + } else { + return false; + } + } + return true; + } - // NOTREACHED - return 0; -} + inline std::string from_i_to_hex(size_t n) { + const char *charset = "0123456789abcdef"; + std::string ret; + do { + ret = charset[n & 15] + ret; + n >>= 4; + } while (n > 0); + return ret; + } + + inline size_t to_utf8(int code, char *buff) { + if (code < 0x0080) { + buff[0] = (code & 0x7F); + return 1; + } else if (code < 0x0800) { + buff[0] = (0xC0 | ((code >> 6) & 0x1F)); + buff[1] = (0x80 | (code & 0x3F)); + return 2; + } else if (code < 0xD800) { + buff[0] = (0xE0 | ((code >> 12) & 0xF)); + buff[1] = (0x80 | ((code >> 6) & 0x3F)); + buff[2] = (0x80 | (code & 0x3F)); + return 3; + } else if (code < 0xE000) { // D800 - DFFF is invalid... + return 0; + } else if (code < 0x10000) { + buff[0] = (0xE0 | ((code >> 12) & 0xF)); + buff[1] = (0x80 | ((code >> 6) & 0x3F)); + buff[2] = (0x80 | (code & 0x3F)); + return 3; + } else if (code < 0x110000) { + buff[0] = (0xF0 | ((code >> 18) & 0x7)); + buff[1] = (0x80 | ((code >> 12) & 0x3F)); + buff[2] = (0x80 | ((code >> 6) & 0x3F)); + buff[3] = (0x80 | (code & 0x3F)); + return 4; + } + + // NOTREACHED + return 0; + } // NOTE: This code came up with the following stackoverflow post: // https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c -inline std::string base64_encode(const std::string &in) { - static const auto lookup = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + inline std::string base64_encode(const std::string &in) { + static const auto lookup = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - std::string out; - out.reserve(in.size()); + std::string out; + out.reserve(in.size()); - int val = 0; - int valb = -6; + int val = 0; + int valb = -6; - for (uint8_t c : in) { - val = (val << 8) + c; - valb += 8; - while (valb >= 0) { - out.push_back(lookup[(val >> valb) & 0x3F]); - valb -= 6; - } - } - - if (valb > -6) { out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); } - - while (out.size() % 4) { - out.push_back('='); - } - - return out; -} - -inline bool is_file(const std::string &path) { - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); -} - -inline bool is_dir(const std::string &path) { - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); -} - -inline bool is_valid_path(const std::string &path) { - size_t level = 0; - size_t i = 0; - - // Skip slash - while (i < path.size() && path[i] == '/') { - i++; - } - - while (i < path.size()) { - // Read component - auto beg = i; - while (i < path.size() && path[i] != '/') { - i++; - } - - auto len = i - beg; - assert(len > 0); - - if (!path.compare(beg, len, ".")) { - ; - } else if (!path.compare(beg, len, "..")) { - if (level == 0) { return false; } - level--; - } else { - level++; - } - - // Skip slash - while (i < path.size() && path[i] == '/') { - i++; - } - } - - return true; -} - -inline void read_file(const std::string &path, std::string &out) { - std::ifstream fs(path, std::ios_base::binary); - fs.seekg(0, std::ios_base::end); - auto size = fs.tellg(); - fs.seekg(0); - out.resize(static_cast(size)); - fs.read(&out[0], size); -} - -inline std::string file_extension(const std::string &path) { - std::smatch m; - auto pat = std::regex("\\.([a-zA-Z0-9]+)$"); - if (std::regex_search(path, m, pat)) { return m[1].str(); } - return std::string(); -} - -template void split(const char *b, const char *e, char d, Fn fn) { - int i = 0; - int beg = 0; - - while (e ? (b + i != e) : (b[i] != '\0')) { - if (b[i] == d) { - fn(&b[beg], &b[i]); - beg = i + 1; - } - i++; - } - - if (i) { fn(&b[beg], &b[i]); } -} - -// NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` -// to store data. The call can set memory on stack for performance. -class stream_line_reader { -public: - stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) - : strm_(strm), fixed_buffer_(fixed_buffer), - fixed_buffer_size_(fixed_buffer_size) {} - - const char *ptr() const { - if (glowable_buffer_.empty()) { - return fixed_buffer_; - } else { - return glowable_buffer_.data(); - } - } - - size_t size() const { - if (glowable_buffer_.empty()) { - return fixed_buffer_used_size_; - } else { - return glowable_buffer_.size(); - } - } - - bool getline() { - fixed_buffer_used_size_ = 0; - glowable_buffer_.clear(); - - for (size_t i = 0;; i++) { - char byte; - auto n = strm_.read(&byte, 1); - - if (n < 0) { - return false; - } else if (n == 0) { - if (i == 0) { - return false; - } else { - break; + for (uint8_t c : in) { + val = (val << 8) + c; + valb += 8; + while (valb >= 0) { + out.push_back(lookup[(val >> valb) & 0x3F]); + valb -= 6; } } - append(byte); + if (valb > -6) { out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); } - if (byte == '\n') { break; } - } - - return true; - } - -private: - void append(char c) { - if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { - fixed_buffer_[fixed_buffer_used_size_++] = c; - fixed_buffer_[fixed_buffer_used_size_] = '\0'; - } else { - if (glowable_buffer_.empty()) { - assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); - glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + while (out.size() % 4) { + out.push_back('='); } - glowable_buffer_ += c; + + return out; } - } - Stream &strm_; - char *fixed_buffer_; - const size_t fixed_buffer_size_; - size_t fixed_buffer_used_size_; - std::string glowable_buffer_; -}; + inline bool is_file(const std::string &path) { + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); + } -inline int close_socket(socket_t sock) { + inline bool is_dir(const std::string &path) { + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); + } + + inline bool is_valid_path(const std::string &path) { + size_t level = 0; + size_t i = 0; + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + + while (i < path.size()) { + // Read component + auto beg = i; + while (i < path.size() && path[i] != '/') { + i++; + } + + auto len = i - beg; + assert(len > 0); + + if (!path.compare(beg, len, ".")) { + ; + } else if (!path.compare(beg, len, "..")) { + if (level == 0) { return false; } + level--; + } else { + level++; + } + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + } + + return true; + } + + inline void read_file(const std::string &path, std::string &out) { + std::ifstream fs(path, std::ios_base::binary); + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + out.resize(static_cast(size)); + fs.read(&out[0], size); + } + + inline std::string file_extension(const std::string &path) { + std::smatch m; + auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + if (std::regex_search(path, m, re)) { return m[1].str(); } + return std::string(); + } + + template void split(const char *b, const char *e, char d, Fn fn) { + int i = 0; + int beg = 0; + + while (e ? (b + i != e) : (b[i] != '\0')) { + if (b[i] == d) { + fn(&b[beg], &b[i]); + beg = i + 1; + } + i++; + } + + if (i) { fn(&b[beg], &b[i]); } + } + +// NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` +// to store data. The call can set memory on stack for performance. + class stream_line_reader { + public: + stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) + : strm_(strm), fixed_buffer_(fixed_buffer), + fixed_buffer_size_(fixed_buffer_size) {} + + const char *ptr() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_; + } else { + return glowable_buffer_.data(); + } + } + + size_t size() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_used_size_; + } else { + return glowable_buffer_.size(); + } + } + + bool end_with_crlf() const { + auto end = ptr() + size(); + return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; + } + + bool getline() { + fixed_buffer_used_size_ = 0; + glowable_buffer_.clear(); + + for (size_t i = 0;; i++) { + char byte; + auto n = strm_.read(&byte, 1); + + if (n < 0) { + return false; + } else if (n == 0) { + if (i == 0) { + return false; + } else { + break; + } + } + + append(byte); + + if (byte == '\n') { break; } + } + + return true; + } + + private: + void append(char c) { + if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { + fixed_buffer_[fixed_buffer_used_size_++] = c; + fixed_buffer_[fixed_buffer_used_size_] = '\0'; + } else { + if (glowable_buffer_.empty()) { + assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); + glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + } + glowable_buffer_ += c; + } + } + + Stream &strm_; + char *fixed_buffer_; + const size_t fixed_buffer_size_; + size_t fixed_buffer_used_size_ = 0; + std::string glowable_buffer_; + }; + + inline int close_socket(socket_t sock) { #ifdef _WIN32 - return closesocket(sock); + return closesocket(sock); #else - return close(sock); + return close(sock); #endif -} + } -inline int select_read(socket_t sock, time_t sec, time_t usec) { + inline int select_read(socket_t sock, time_t sec, time_t usec) { #ifdef CPPHTTPLIB_USE_POLL - struct pollfd pfd_read; - pfd_read.fd = sock; - pfd_read.events = POLLIN; + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN; - auto timeout = static_cast(sec * 1000 + usec / 1000); + auto timeout = static_cast(sec * 1000 + usec / 1000); - return poll(&pfd_read, 1, timeout); + return poll(&pfd_read, 1, timeout); #else - fd_set fds; + fd_set fds; FD_ZERO(&fds); FD_SET(sock, &fds); @@ -1053,26 +1194,27 @@ inline int select_read(socket_t sock, time_t sec, time_t usec) { return select(static_cast(sock + 1), &fds, nullptr, nullptr, &tv); #endif -} + } -inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { + inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { #ifdef CPPHTTPLIB_USE_POLL - struct pollfd pfd_read; - pfd_read.fd = sock; - pfd_read.events = POLLIN | POLLOUT; + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN | POLLOUT; - auto timeout = static_cast(sec * 1000 + usec / 1000); + auto timeout = static_cast(sec * 1000 + usec / 1000); - if (poll(&pfd_read, 1, timeout) > 0 && - pfd_read.revents & (POLLIN | POLLOUT)) { - int error = 0; - socklen_t len = sizeof(error); - return getsockopt(sock, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len) >= 0 && - !error; - } - return false; + if (poll(&pfd_read, 1, timeout) > 0 && + pfd_read.revents & (POLLIN | POLLOUT)) { + int error = 0; + socklen_t len = sizeof(error); + return getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len) >= 0 && + !error; + } + return false; #else - fd_set fdsr; + fd_set fdsr; FD_ZERO(&fdsr); FD_SET(sock, &fdsr); @@ -1087,58 +1229,61 @@ inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { int error = 0; socklen_t len = sizeof(error); - return getsockopt(sock, SOL_SOCKET, SO_ERROR, (char *)&error, &len) >= 0 && + return getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len) >= 0 && !error; } return false; #endif -} - -template -inline bool process_and_close_socket(bool is_client_request, socket_t sock, - size_t keep_alive_max_count, T callback) { - assert(keep_alive_max_count > 0); - - bool ret = false; - - if (keep_alive_max_count > 1) { - auto count = keep_alive_max_count; - while (count > 0 && - (is_client_request || - detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, - CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { - SocketStream strm(sock); - auto last_connection = count == 1; - auto connection_close = false; - - ret = callback(strm, last_connection, connection_close); - if (!ret || connection_close) { break; } - - count--; } - } else { - SocketStream strm(sock); - auto dummy_connection_close = false; - ret = callback(strm, true, dummy_connection_close); - } - close_socket(sock); - return ret; -} + template + inline bool process_and_close_socket(bool is_client_request, socket_t sock, + size_t keep_alive_max_count, + time_t read_timeout_sec, + time_t read_timeout_usec, T callback) { + assert(keep_alive_max_count > 0); -inline int shutdown_socket(socket_t sock) { + bool ret = false; + + if (keep_alive_max_count > 1) { + auto count = keep_alive_max_count; + while (count > 0 && + (is_client_request || + detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, + CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec); + auto last_connection = count == 1; + auto connection_close = false; + + ret = callback(strm, last_connection, connection_close); + if (!ret || connection_close) { break; } + + count--; + } + } else { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec); + auto dummy_connection_close = false; + ret = callback(strm, true, dummy_connection_close); + } + + close_socket(sock); + return ret; + } + + inline int shutdown_socket(socket_t sock) { #ifdef _WIN32 - return shutdown(sock, SD_BOTH); + return shutdown(sock, SD_BOTH); #else - return shutdown(sock, SHUT_RDWR); + return shutdown(sock, SHUT_RDWR); #endif -} + } -template -socket_t create_socket(const char *host, int port, Fn fn, - int socket_flags = 0) { + template + socket_t create_socket(const char *host, int port, Fn fn, + int socket_flags = 0) { #ifdef _WIN32 -#define SO_SYNCHRONOUS_NONALERT 0x20 + #define SO_SYNCHRONOUS_NONALERT 0x20 #define SO_OPENTYPE 0x7008 int opt = SO_SYNCHRONOUS_NONALERT; @@ -1146,147 +1291,168 @@ socket_t create_socket(const char *host, int port, Fn fn, sizeof(opt)); #endif - // Get address info - struct addrinfo hints; - struct addrinfo *result; + // Get address info + struct addrinfo hints; + struct addrinfo *result; - memset(&hints, 0, sizeof(struct addrinfo)); - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - hints.ai_flags = socket_flags; - hints.ai_protocol = 0; + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = socket_flags; + hints.ai_protocol = 0; - auto service = std::to_string(port); + auto service = std::to_string(port); - if (getaddrinfo(host, service.c_str(), &hints, &result)) { - return INVALID_SOCKET; - } + if (getaddrinfo(host, service.c_str(), &hints, &result)) { + return INVALID_SOCKET; + } - for (auto rp = result; rp; rp = rp->ai_next) { - // Create a socket + for (auto rp = result; rp; rp = rp->ai_next) { + // Create a socket #ifdef _WIN32 - auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, + auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, nullptr, 0, WSA_FLAG_NO_HANDLE_INHERIT); #else - auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); #endif - if (sock == INVALID_SOCKET) { continue; } + if (sock == INVALID_SOCKET) { continue; } #ifndef _WIN32 - if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { continue; } + if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { continue; } #endif - // Make 'reuse address' option available - int yes = 1; - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), sizeof(yes)); + // Make 'reuse address' option available + int yes = 1; + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), + sizeof(yes)); #ifdef SO_REUSEPORT - setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&yes), sizeof(yes)); + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&yes), + sizeof(yes)); #endif - // bind or connect - if (fn(sock, *rp)) { + // bind or connect + if (fn(sock, *rp)) { + freeaddrinfo(result); + return sock; + } + + close_socket(sock); + } + freeaddrinfo(result); - return sock; + return INVALID_SOCKET; } - close_socket(sock); - } - - freeaddrinfo(result); - return INVALID_SOCKET; -} - -inline void set_nonblocking(socket_t sock, bool nonblocking) { + inline void set_nonblocking(socket_t sock, bool nonblocking) { #ifdef _WIN32 - auto flags = nonblocking ? 1UL : 0UL; + auto flags = nonblocking ? 1UL : 0UL; ioctlsocket(sock, FIONBIO, &flags); #else - auto flags = fcntl(sock, F_GETFL, 0); - fcntl(sock, F_SETFL, - nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); + auto flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, + nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); #endif -} - -inline bool is_connection_error() { -#ifdef _WIN32 - return WSAGetLastError() != WSAEWOULDBLOCK; -#else - return errno != EINPROGRESS; -#endif -} - -inline std::string get_remote_addr(socket_t sock) { - struct sockaddr_storage addr; - socklen_t len = sizeof(addr); - - if (!getpeername(sock, reinterpret_cast(&addr), &len)) { - char ipstr[NI_MAXHOST]; - - if (!getnameinfo(reinterpret_cast(&addr), len, ipstr, sizeof(ipstr), - nullptr, 0, NI_NUMERICHOST)) { - return ipstr; } - } - return std::string(); -} + inline bool is_connection_error() { +#ifdef _WIN32 + return WSAGetLastError() != WSAEWOULDBLOCK; +#else + return errno != EINPROGRESS; +#endif + } -inline const char *find_content_type(const std::string &path) { - auto ext = file_extension(path); - if (ext == "txt") { - return "text/plain"; - } else if (ext == "html") { - return "text/html"; - } else if (ext == "css") { - return "text/css"; - } else if (ext == "jpeg" || ext == "jpg") { - return "image/jpg"; - } else if (ext == "png") { - return "image/png"; - } else if (ext == "gif") { - return "image/gif"; - } else if (ext == "svg") { - return "image/svg+xml"; - } else if (ext == "ico") { - return "image/x-icon"; - } else if (ext == "json") { - return "application/json"; - } else if (ext == "pdf") { - return "application/pdf"; - } else if (ext == "js") { - return "application/javascript"; - } else if (ext == "xml") { - return "application/xml"; - } else if (ext == "xhtml") { - return "application/xhtml+xml"; - } - return nullptr; -} + inline socket_t create_client_socket( + const char *host, int port, time_t timeout_sec) { + return create_socket( + host, port, [=](socket_t sock, struct addrinfo &ai) -> bool { + set_nonblocking(sock, true); -inline const char *status_message(int status) { - switch (status) { - case 200: return "OK"; - case 206: return "Partial Content"; - case 301: return "Moved Permanently"; - case 302: return "Found"; - case 303: return "See Other"; - case 304: return "Not Modified"; - case 400: return "Bad Request"; - case 401: return "Unauthorized"; - case 403: return "Forbidden"; - case 404: return "Not Found"; - case 413: return "Payload Too Large"; - case 414: return "Request-URI Too Long"; - case 415: return "Unsupported Media Type"; - case 416: return "Range Not Satisfiable"; + auto ret = ::connect(sock, ai.ai_addr, static_cast(ai.ai_addrlen)); + if (ret < 0) { + if (is_connection_error() || + !wait_until_socket_is_ready(sock, timeout_sec, 0)) { + close_socket(sock); + return false; + } + } - default: - case 500: return "Internal Server Error"; - } -} + set_nonblocking(sock, false); + return true; + }); + } + + inline std::string get_remote_addr(socket_t sock) { + struct sockaddr_storage addr; + socklen_t len = sizeof(addr); + + if (!getpeername(sock, reinterpret_cast(&addr), &len)) { + std::array ipstr{}; + + if (!getnameinfo(reinterpret_cast(&addr), len, + ipstr.data(), ipstr.size(), nullptr, 0, NI_NUMERICHOST)) { + return ipstr.data(); + } + } + + return std::string(); + } + + inline const char *find_content_type(const std::string &path) { + auto ext = file_extension(path); + if (ext == "txt") { + return "text/plain"; + } else if (ext == "html" || ext == "htm") { + return "text/html"; + } else if (ext == "css") { + return "text/css"; + } else if (ext == "jpeg" || ext == "jpg") { + return "image/jpg"; + } else if (ext == "png") { + return "image/png"; + } else if (ext == "gif") { + return "image/gif"; + } else if (ext == "svg") { + return "image/svg+xml"; + } else if (ext == "ico") { + return "image/x-icon"; + } else if (ext == "json") { + return "application/json"; + } else if (ext == "pdf") { + return "application/pdf"; + } else if (ext == "js") { + return "application/javascript"; + } else if (ext == "xml") { + return "application/xml"; + } else if (ext == "xhtml") { + return "application/xhtml+xml"; + } + return nullptr; + } + + inline const char *status_message(int status) { + switch (status) { + case 200: return "OK"; + case 206: return "Partial Content"; + case 301: return "Moved Permanently"; + case 302: return "Found"; + case 303: return "See Other"; + case 304: return "Not Modified"; + case 400: return "Bad Request"; + case 403: return "Forbidden"; + case 404: return "Not Found"; + case 413: return "Payload Too Large"; + case 414: return "Request-URI Too Long"; + case 415: return "Unsupported Media Type"; + case 416: return "Range Not Satisfiable"; + + default: + case 500: return "Internal Server Error"; + } + } #ifdef CPPHTTPLIB_ZLIB_SUPPORT -inline bool can_compress(const std::string &content_type) { + inline bool can_compress(const std::string &content_type) { return !content_type.find("text/") || content_type == "image/svg+xml" || content_type == "application/javascript" || content_type == "application/json" || @@ -1305,18 +1471,18 @@ inline bool compress(std::string &content) { if (ret != Z_OK) { return false; } strm.avail_in = content.size(); - strm.next_in = const_cast(reinterpret_cast(content.data())); + strm.next_in = + const_cast(reinterpret_cast(content.data())); std::string compressed; - const auto bufsiz = 16384; - char buff[bufsiz]; + std::array buff{}; do { - strm.avail_out = bufsiz; - strm.next_out = reinterpret_cast(buff); + strm.avail_out = buff.size(); + strm.next_out = reinterpret_cast(buff.data()); ret = deflate(&strm, Z_FINISH); assert(ret != Z_STREAM_ERROR); - compressed.append(buff, bufsiz - strm.avail_out); + compressed.append(buff.data(), buff.size() - strm.avail_out); } while (strm.avail_out == 0); assert(ret == Z_STREAM_END); @@ -1350,13 +1516,12 @@ public: int ret = Z_OK; strm.avail_in = data_length; - strm.next_in = const_cast(reinterpret_cast(data)); + strm.next_in = const_cast(reinterpret_cast(data)); - const auto bufsiz = 16384; - char buff[bufsiz]; + std::array buff{}; do { - strm.avail_out = bufsiz; - strm.next_out = reinterpret_cast(buff); + strm.avail_out = buff.size(); + strm.next_out = reinterpret_cast(buff.data()); ret = inflate(&strm, Z_NO_FLUSH); assert(ret != Z_STREAM_ERROR); @@ -1366,10 +1531,12 @@ public: case Z_MEM_ERROR: inflateEnd(&strm); return false; } - if (!callback(buff, bufsiz - strm.avail_out)) { return false; } + if (!callback(buff.data(), buff.size() - strm.avail_out)) { + return false; + } } while (strm.avail_out == 0); - return ret == Z_STREAM_END; + return ret == Z_OK || ret == Z_STREAM_END; } private: @@ -1378,148 +1545,162 @@ private: }; #endif -inline bool has_header(const Headers &headers, const char *key) { - return headers.find(key) != headers.end(); -} - -inline const char *get_header_value(const Headers &headers, const char *key, - size_t id = 0, const char *def = nullptr) { - auto it = headers.find(key); - std::advance(it, id); - if (it != headers.end()) { return it->second.c_str(); } - return def; -} - -inline uint64_t get_header_value_uint64(const Headers &headers, const char *key, - int def = 0) { - auto it = headers.find(key); - if (it != headers.end()) { - return std::strtoull(it->second.data(), nullptr, 10); - } - return def; -} - -inline bool read_headers(Stream &strm, Headers &headers) { - static std::regex re(R"((.+?):\s*(.+?)\s*\r\n)"); - - const auto bufsiz = 2048; - char buf[bufsiz]; - - stream_line_reader reader(strm, buf, bufsiz); - - for (;;) { - if (!reader.getline()) { return false; } - if (!strcmp(reader.ptr(), "\r\n")) { break; } - std::cmatch m; - if (std::regex_match(reader.ptr(), m, re)) { - auto key = std::string(m[1]); - auto val = std::string(m[2]); - headers.emplace(key, val); + inline bool has_header(const Headers &headers, const char *key) { + return headers.find(key) != headers.end(); } - } - return true; -} - -typedef std::function - ContentReceiverCore; - -inline bool read_content_with_length(Stream &strm, uint64_t len, - Progress progress, - ContentReceiverCore out) { - char buf[CPPHTTPLIB_RECV_BUFSIZ]; - - uint64_t r = 0; - while (r < len) { - auto read_len = static_cast(len - r); - auto n = strm.read(buf, std::min(read_len, CPPHTTPLIB_RECV_BUFSIZ)); - if (n <= 0) { return false; } - - if (!out(buf, n)) { return false; } - - r += n; - - if (progress) { - if (!progress(r, len)) { return false; } + inline const char *get_header_value(const Headers &headers, const char *key, + size_t id = 0, const char *def = nullptr) { + auto it = headers.find(key); + std::advance(it, id); + if (it != headers.end()) { return it->second.c_str(); } + return def; } - } - return true; -} + inline uint64_t get_header_value_uint64(const Headers &headers, const char *key, + int def = 0) { + auto it = headers.find(key); + if (it != headers.end()) { + return std::strtoull(it->second.data(), nullptr, 10); + } + return def; + } -inline void skip_content_with_length(Stream &strm, uint64_t len) { - char buf[CPPHTTPLIB_RECV_BUFSIZ]; - uint64_t r = 0; - while (r < len) { - auto read_len = static_cast(len - r); - auto n = strm.read(buf, std::min(read_len, CPPHTTPLIB_RECV_BUFSIZ)); - if (n <= 0) { return; } - r += n; - } -} + inline bool read_headers(Stream &strm, Headers &headers) { + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); + + for (;;) { + if (!line_reader.getline()) { return false; } + + // Check if the line ends with CRLF. + if (line_reader.end_with_crlf()) { + // Blank line indicates end of headers. + if (line_reader.size() == 2) { break; } + } else { + continue; // Skip invalid line. + } + + // Skip trailing spaces and tabs. + auto end = line_reader.ptr() + line_reader.size() - 2; + while (line_reader.ptr() < end && (end[-1] == ' ' || end[-1] == '\t')) { + end--; + } + + // Horizontal tab and ' ' are considered whitespace and are ignored when on + // the left or right side of the header value: + // - https://stackoverflow.com/questions/50179659/ + // - https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html + static const std::regex re(R"((.+?):[\t ]*(.+))"); + + std::cmatch m; + if (std::regex_match(line_reader.ptr(), end, m, re)) { + auto key = std::string(m[1]); + auto val = std::string(m[2]); + headers.emplace(key, val); + } + } -inline bool read_content_without_length(Stream &strm, ContentReceiverCore out) { - char buf[CPPHTTPLIB_RECV_BUFSIZ]; - for (;;) { - auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); - if (n < 0) { - return false; - } else if (n == 0) { return true; } - if (!out(buf, n)) { return false; } - } - return true; -} + inline bool read_content_with_length(Stream &strm, uint64_t len, + Progress progress, ContentReceiver out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; -inline bool read_content_chunked(Stream &strm, ContentReceiverCore out) { - const auto bufsiz = 16; - char buf[bufsiz]; + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, std::min(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return false; } - stream_line_reader reader(strm, buf, bufsiz); + if (!out(buf, n)) { return false; } - if (!reader.getline()) { return false; } + r += n; - auto chunk_len = std::stoi(reader.ptr(), 0, 16); + if (progress) { + if (!progress(r, len)) { return false; } + } + } - while (chunk_len > 0) { - if (!read_content_with_length(strm, chunk_len, nullptr, out)) { - return false; + return true; } - if (!reader.getline()) { return false; } + inline void skip_content_with_length(Stream &strm, uint64_t len) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, std::min(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return; } + r += n; + } + } - if (strcmp(reader.ptr(), "\r\n")) { break; } + inline bool read_content_without_length(Stream &strm, ContentReceiver out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + for (;;) { + auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); + if (n < 0) { + return false; + } else if (n == 0) { + return true; + } + if (!out(buf, n)) { return false; } + } - if (!reader.getline()) { return false; } + return true; + } - chunk_len = std::stoi(reader.ptr(), 0, 16); - } + inline bool read_content_chunked(Stream &strm, ContentReceiver out) { + const auto bufsiz = 16; + char buf[bufsiz]; - if (chunk_len == 0) { - // Reader terminator after chunks - if (!reader.getline() || strcmp(reader.ptr(), "\r\n")) return false; - } + stream_line_reader line_reader(strm, buf, bufsiz); - return true; -} + if (!line_reader.getline()) { return false; } -inline bool is_chunked_transfer_encoding(const Headers &headers) { - return !strcasecmp(get_header_value(headers, "Transfer-Encoding", 0, ""), - "chunked"); -} + auto chunk_len = std::stoi(line_reader.ptr(), 0, 16); -template -bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, - Progress progress, ContentReceiverCore receiver) { + while (chunk_len > 0) { + if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + return false; + } - ContentReceiverCore out = [&](const char *buf, size_t n) { - return receiver(buf, n); - }; + if (!line_reader.getline()) { return false; } + + if (strcmp(line_reader.ptr(), "\r\n")) { break; } + + if (!line_reader.getline()) { return false; } + + chunk_len = std::stoi(line_reader.ptr(), 0, 16); + } + + if (chunk_len == 0) { + // Reader terminator after chunks + if (!line_reader.getline() || strcmp(line_reader.ptr(), "\r\n")) + return false; + } + + return true; + } + + inline bool is_chunked_transfer_encoding(const Headers &headers) { + return !strcasecmp(get_header_value(headers, "Transfer-Encoding", 0, ""), + "chunked"); + } + + template + bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, + Progress progress, ContentReceiver receiver) { + + ContentReceiver out = [&](const char *buf, size_t n) { + return receiver(buf, n); + }; #ifdef CPPHTTPLIB_ZLIB_SUPPORT - detail::decompressor decompressor; + detail::decompressor decompressor; if (!decompressor.is_valid()) { status = 500; @@ -1533,471 +1714,581 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, }; } #else - if (x.get_header_value("Content-Encoding") == "gzip") { - status = 415; - return false; - } + if (x.get_header_value("Content-Encoding") == "gzip") { + status = 415; + return false; + } #endif - auto ret = true; - auto exceed_payload_max_length = false; + auto ret = true; + auto exceed_payload_max_length = false; - if (is_chunked_transfer_encoding(x.headers)) { - ret = read_content_chunked(strm, out); - } else if (!has_header(x.headers, "Content-Length")) { - ret = read_content_without_length(strm, out); - } else { - auto len = get_header_value_uint64(x.headers, "Content-Length", 0); - if (len > payload_max_length) { - exceed_payload_max_length = true; - skip_content_with_length(strm, len); - ret = false; - } else if (len > 0) { - ret = read_content_with_length(strm, len, progress, out); + if (is_chunked_transfer_encoding(x.headers)) { + ret = read_content_chunked(strm, out); + } else if (!has_header(x.headers, "Content-Length")) { + ret = read_content_without_length(strm, out); + } else { + auto len = get_header_value_uint64(x.headers, "Content-Length", 0); + if (len > payload_max_length) { + exceed_payload_max_length = true; + skip_content_with_length(strm, len); + ret = false; + } else if (len > 0) { + ret = read_content_with_length(strm, len, progress, out); + } + } + + if (!ret) { status = exceed_payload_max_length ? 413 : 400; } + + return ret; } - } - if (!ret) { status = exceed_payload_max_length ? 413 : 400; } + template + inline int write_headers(Stream &strm, const T &info, const Headers &headers) { + auto write_len = 0; + for (const auto &x : info.headers) { + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { return len; } + write_len += len; + } + for (const auto &x : headers) { + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { return len; } + write_len += len; + } + auto len = strm.write("\r\n"); + if (len < 0) { return len; } + write_len += len; + return write_len; + } - return ret; -} + inline ssize_t write_content(Stream &strm, + ContentProviderWithCloser content_provider, + size_t offset, size_t length) { + size_t begin_offset = offset; + size_t end_offset = offset + length; + while (offset < end_offset) { + ssize_t written_length = 0; + content_provider( + offset, end_offset - offset, + [&](const char *d, size_t l) { + offset += l; + written_length = strm.write(d, l); + }, + [&](void) { written_length = -1; }); + if (written_length < 0) { return written_length; } + } + return static_cast(offset - begin_offset); + } -template -inline int write_headers(Stream &strm, const T &info, const Headers &headers) { - auto write_len = 0; - for (const auto &x : info.headers) { - auto len = - strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); - if (len < 0) { return len; } - write_len += len; - } - for (const auto &x : headers) { - auto len = - strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); - if (len < 0) { return len; } - write_len += len; - } - auto len = strm.write("\r\n"); - if (len < 0) { return len; } - write_len += len; - return write_len; -} + inline ssize_t + write_content_chunked(Stream &strm, + ContentProviderWithCloser content_provider) { + size_t offset = 0; + auto data_available = true; + ssize_t total_written_length = 0; + while (data_available) { + ssize_t written_length = 0; + content_provider( + offset, 0, + [&](const char *d, size_t l) { + data_available = l > 0; + offset += l; -inline ssize_t write_content(Stream &strm, ContentProvider content_provider, - size_t offset, size_t length) { - size_t begin_offset = offset; - size_t end_offset = offset + length; - while (offset < end_offset) { - ssize_t written_length = 0; - content_provider( - offset, end_offset - offset, - [&](const char *d, size_t l) { - offset += l; - written_length = strm.write(d, l); - }, - [&](void) { written_length = -1; }); - if (written_length < 0) { return written_length; } - } - return static_cast(offset - begin_offset); -} + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(l) + "\r\n" + std::string(d, l) + "\r\n"; + written_length = strm.write(chunk); + }, + [&](void) { + data_available = false; + written_length = strm.write("0\r\n\r\n"); + }); -inline ssize_t write_content_chunked(Stream &strm, - ContentProvider content_provider) { - size_t offset = 0; - auto data_available = true; - ssize_t total_written_length = 0; - while (data_available) { - ssize_t written_length = 0; - content_provider( - offset, 0, - [&](const char *d, size_t l) { - data_available = l > 0; - offset += l; + if (written_length < 0) { return written_length; } + total_written_length += written_length; + } + return total_written_length; + } - // Emit chunked response header and footer for each chunk - auto chunk = from_i_to_hex(l) + "\r\n" + std::string(d, l) + "\r\n"; - written_length = strm.write(chunk); - }, - [&](void) { - data_available = false; - written_length = strm.write("0\r\n\r\n"); + template + inline bool redirect(T &cli, const Request &req, Response &res, + const std::string &path) { + Request new_req; + new_req.method = req.method; + new_req.path = path; + new_req.headers = req.headers; + new_req.body = req.body; + new_req.redirect_count = req.redirect_count - 1; + new_req.response_handler = req.response_handler; + new_req.content_receiver = req.content_receiver; + new_req.progress = req.progress; + + Response new_res; + auto ret = cli.send(new_req, new_res); + if (ret) { res = new_res; } + return ret; + } + + inline std::string encode_url(const std::string &s) { + std::string result; + + for (auto i = 0; s[i]; i++) { + switch (s[i]) { + case ' ': result += "%20"; break; + case '+': result += "%2B"; break; + case '\r': result += "%0D"; break; + case '\n': result += "%0A"; break; + case '\'': result += "%27"; break; + case ',': result += "%2C"; break; + case ':': result += "%3A"; break; + case ';': result += "%3B"; break; + default: + auto c = static_cast(s[i]); + if (c >= 0x80) { + result += '%'; + char hex[4]; + size_t len = snprintf(hex, sizeof(hex) - 1, "%02X", c); + assert(len == 2); + result.append(hex, len); + } else { + result += s[i]; + } + break; + } + } + + return result; + } + + inline std::string decode_url(const std::string &s) { + std::string result; + + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '%' && i + 1 < s.size()) { + if (s[i + 1] == 'u') { + int val = 0; + if (from_hex_to_i(s, i + 2, 4, val)) { + // 4 digits Unicode codes + char buff[4]; + size_t len = to_utf8(val, buff); + if (len > 0) { result.append(buff, len); } + i += 5; // 'u0000' + } else { + result += s[i]; + } + } else { + int val = 0; + if (from_hex_to_i(s, i + 1, 2, val)) { + // 2 digits hex codes + result += static_cast(val); + i += 2; // '00' + } else { + result += s[i]; + } + } + } else if (s[i] == '+') { + result += ' '; + } else { + result += s[i]; + } + } + + return result; + } + + inline void parse_query_text(const std::string &s, Params ¶ms) { + split(&s[0], &s[s.size()], '&', [&](const char *b, const char *e) { + std::string key; + std::string val; + split(b, e, '=', [&](const char *b, const char *e) { + if (key.empty()) { + key.assign(b, e); + } else { + val.assign(b, e); + } }); - - if (written_length < 0) { return written_length; } - total_written_length += written_length; - } - return total_written_length; -} - -template -inline bool redirect(T &cli, const Request &req, Response &res, - const std::string &path) { - Request new_req; - new_req.method = req.method; - new_req.path = path; - new_req.headers = req.headers; - new_req.body = req.body; - new_req.redirect_count = req.redirect_count - 1; - new_req.response_handler = req.response_handler; - new_req.content_receiver = req.content_receiver; - new_req.progress = req.progress; - - Response new_res; - auto ret = cli.send(new_req, new_res); - if (ret) { res = new_res; } - return ret; -} - -inline std::string encode_url(const std::string &s) { - std::string result; - - for (auto i = 0; s[i]; i++) { - switch (s[i]) { - case ' ': result += "%20"; break; - case '+': result += "%2B"; break; - case '\r': result += "%0D"; break; - case '\n': result += "%0A"; break; - case '\'': result += "%27"; break; - case ',': result += "%2C"; break; - case ':': result += "%3A"; break; - case ';': result += "%3B"; break; - default: - auto c = static_cast(s[i]); - if (c >= 0x80) { - result += '%'; - char hex[4]; - size_t len = snprintf(hex, sizeof(hex) - 1, "%02X", c); - assert(len == 2); - result.append(hex, len); - } else { - result += s[i]; - } - break; - } - } - - return result; -} - -inline std::string decode_url(const std::string &s) { - std::string result; - - for (size_t i = 0; i < s.size(); i++) { - if (s[i] == '%' && i + 1 < s.size()) { - if (s[i + 1] == 'u') { - int val = 0; - if (from_hex_to_i(s, i + 2, 4, val)) { - // 4 digits Unicode codes - char buff[4]; - size_t len = to_utf8(val, buff); - if (len > 0) { result.append(buff, len); } - i += 5; // 'u0000' - } else { - result += s[i]; - } - } else { - int val = 0; - if (from_hex_to_i(s, i + 1, 2, val)) { - // 2 digits hex codes - result += static_cast(val); - i += 2; // '00' - } else { - result += s[i]; - } - } - } else if (s[i] == '+') { - result += ' '; - } else { - result += s[i]; - } - } - - return result; -} - -inline void parse_query_text(const std::string &s, Params ¶ms) { - split(&s[0], &s[s.size()], '&', [&](const char *b, const char *e) { - std::string key; - std::string val; - split(b, e, '=', [&](const char *b, const char *e) { - if (key.empty()) { - key.assign(b, e); - } else { - val.assign(b, e); - } - }); - params.emplace(key, decode_url(val)); - }); -} - -inline bool parse_multipart_boundary(const std::string &content_type, - std::string &boundary) { - auto pos = content_type.find("boundary="); - if (pos == std::string::npos) { return false; } - - boundary = content_type.substr(pos + 9); - return true; -} - -inline bool parse_multipart_formdata(const std::string &boundary, - const std::string &body, - MultipartFiles &files) { - static std::string dash = "--"; - static std::string crlf = "\r\n"; - - static std::regex re_content_type("Content-Type: (.*?)", - std::regex_constants::icase); - - static std::regex re_content_disposition( - "Content-Disposition: form-data; name=\"(.*?)\"(?:; filename=\"(.*?)\")?", - std::regex_constants::icase); - - auto dash_boundary = dash + boundary; - - auto pos = body.find(dash_boundary); - if (pos != 0) { return false; } - - pos += dash_boundary.size(); - - auto next_pos = body.find(crlf, pos); - if (next_pos == std::string::npos) { return false; } - - pos = next_pos + crlf.size(); - - while (pos < body.size()) { - next_pos = body.find(crlf, pos); - if (next_pos == std::string::npos) { return false; } - - std::string name; - MultipartFile file; - - auto header = body.substr(pos, (next_pos - pos)); - - while (pos != next_pos) { - std::smatch m; - if (std::regex_match(header, m, re_content_type)) { - file.content_type = m[1]; - } else if (std::regex_match(header, m, re_content_disposition)) { - name = m[1]; - file.filename = m[2]; - } - - pos = next_pos + crlf.size(); - - next_pos = body.find(crlf, pos); - if (next_pos == std::string::npos) { return false; } - - header = body.substr(pos, (next_pos - pos)); + params.emplace(key, decode_url(val)); + }); } - pos = next_pos + crlf.size(); + inline bool parse_multipart_boundary(const std::string &content_type, + std::string &boundary) { + auto pos = content_type.find("boundary="); + if (pos == std::string::npos) { return false; } - next_pos = body.find(crlf + dash_boundary, pos); - - if (next_pos == std::string::npos) { return false; } - - file.offset = pos; - file.length = next_pos - pos; - - pos = next_pos + crlf.size() + dash_boundary.size(); - - next_pos = body.find(crlf, pos); - if (next_pos == std::string::npos) { return false; } - - files.emplace(name, file); - - pos = next_pos + crlf.size(); - } - - return true; -} - -inline bool parse_range_header(const std::string &s, Ranges &ranges) { - try { - static auto re = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); - std::smatch m; - if (std::regex_match(s, m, re)) { - auto pos = m.position(1); - auto len = m.length(1); - detail::split(&s[pos], &s[pos + len], ',', - [&](const char *b, const char *e) { - static auto re = std::regex(R"(\s*(\d*)-(\d*))"); - std::cmatch m; - if (std::regex_match(b, e, m, re)) { - ssize_t first = -1; - if (!m.str(1).empty()) { - first = static_cast(std::stoll(m.str(1))); - } - - ssize_t last = -1; - if (!m.str(2).empty()) { - last = static_cast(std::stoll(m.str(2))); - } - - if (first != -1 && last != -1 && first > last) { - throw std::runtime_error("invalid range error"); - } - ranges.emplace_back(std::make_pair(first, last)); - } - }); + boundary = content_type.substr(pos + 9); return true; } - return false; - } catch (...) { return false; } -} -inline std::string to_lower(const char *beg, const char *end) { - std::string out; - auto it = beg; - while (it != end) { - out += static_cast(::tolower(*it)); - it++; - } - return out; -} + inline bool parse_range_header(const std::string &s, Ranges &ranges) { + static auto re_first_range = + std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); + std::smatch m; + if (std::regex_match(s, m, re_first_range)) { + auto pos = m.position(1); + auto len = m.length(1); + bool all_valid_ranges = true; + detail::split( + &s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { + if (!all_valid_ranges) return; + static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))"); + std::cmatch m; + if (std::regex_match(b, e, m, re_another_range)) { + ssize_t first = -1; + if (!m.str(1).empty()) { + first = static_cast(std::stoll(m.str(1))); + } -inline std::string make_multipart_data_boundary() { - static const char data[] = - "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + ssize_t last = -1; + if (!m.str(2).empty()) { + last = static_cast(std::stoll(m.str(2))); + } - std::random_device seed_gen; - std::mt19937 engine(seed_gen()); - - std::string result = "--cpp-httplib-multipart-data-"; - - for (auto i = 0; i < 16; i++) { - result += data[engine() % (sizeof(data) - 1)]; - } - - return result; -} - -inline std::pair -get_range_offset_and_length(const Request &req, size_t content_length, - size_t index) { - auto r = req.ranges[index]; - - if (r.first == -1 && r.second == -1) { - return std::make_pair(0, content_length); - } - - if (r.first == -1) { - r.first = content_length - r.second; - r.second = content_length - 1; - } - - if (r.second == -1) { r.second = content_length - 1; } - - return std::make_pair(r.first, r.second - r.first + 1); -} - -inline std::string make_content_range_header_field(size_t offset, size_t length, - size_t content_length) { - std::string field = "bytes "; - field += std::to_string(offset); - field += "-"; - field += std::to_string(offset + length - 1); - field += "/"; - field += std::to_string(content_length); - return field; -} - -template -bool process_multipart_ranges_data(const Request &req, Response &res, - const std::string &boundary, - const std::string &content_type, - SToken stoken, CToken ctoken, - Content content) { - for (size_t i = 0; i < req.ranges.size(); i++) { - ctoken("--"); - stoken(boundary); - ctoken("\r\n"); - if (!content_type.empty()) { - ctoken("Content-Type: "); - stoken(content_type); - ctoken("\r\n"); + if (first != -1 && last != -1 && first > last) { + all_valid_ranges = false; + return; + } + ranges.emplace_back(std::make_pair(first, last)); + } + }); + return all_valid_ranges; + } + return false; } - auto offsets = detail::get_range_offset_and_length(req, res.body.size(), i); - auto offset = offsets.first; - auto length = offsets.second; + class MultipartFormDataParser { + public: + MultipartFormDataParser() {} - ctoken("Content-Range: "); - stoken(make_content_range_header_field(offset, length, res.body.size())); - ctoken("\r\n"); - ctoken("\r\n"); - if (!content(offset, length)) { return false; } - ctoken("\r\n"); - } + void set_boundary(const std::string &boundary) { + boundary_ = boundary; + } - ctoken("--"); - stoken(boundary); - ctoken("--\r\n"); + bool is_valid() const { return is_valid_; } - return true; -} + template + bool parse(const char *buf, size_t n, T content_callback, U header_callback) { + static const std::regex re_content_type(R"(^Content-Type:\s*(.*?)\s*$)", + std::regex_constants::icase); -inline std::string make_multipart_ranges_data(const Request &req, Response &res, - const std::string &boundary, - const std::string &content_type) { - std::string data; + static const std::regex re_content_disposition( + "^Content-Disposition:\\s*form-data;\\s*name=\"(.*?)\"(?:;\\s*filename=" + "\"(.*?)\")?\\s*$", + std::regex_constants::icase); + + buf_.append(buf, n); // TODO: performance improvement + + while (!buf_.empty()) { + switch (state_) { + case 0: { // Initial boundary + auto pattern = dash_ + boundary_ + crlf_; + if (pattern.size() > buf_.size()) { return true; } + auto pos = buf_.find(pattern); + if (pos != 0) { + is_done_ = true; + return false; + } + buf_.erase(0, pattern.size()); + off_ += pattern.size(); + state_ = 1; + break; + } + case 1: { // New entry + clear_file_info(); + state_ = 2; + break; + } + case 2: { // Headers + auto pos = buf_.find(crlf_); + while (pos != std::string::npos) { + if (pos == 0) { + if (!header_callback(name_, file_)) { + is_valid_ = false; + is_done_ = false; + return false; + } + buf_.erase(0, crlf_.size()); + off_ += crlf_.size(); + state_ = 3; + break; + } + + auto header = buf_.substr(0, pos); + { + std::smatch m; + if (std::regex_match(header, m, re_content_type)) { + file_.content_type = m[1]; + } else if (std::regex_match(header, m, re_content_disposition)) { + name_ = m[1]; + file_.filename = m[2]; + } + } + + buf_.erase(0, pos + crlf_.size()); + off_ += pos + crlf_.size(); + pos = buf_.find(crlf_); + } + break; + } + case 3: { // Body + { + auto pattern = crlf_ + dash_; + auto pos = buf_.find(pattern); + if (pos == std::string::npos) { + pos = buf_.size(); + } + if (!content_callback(name_, buf_.data(), pos)) { + is_valid_ = false; + is_done_ = false; + return false; + } + + off_ += pos; + buf_.erase(0, pos); + } + + { + auto pattern = crlf_ + dash_ + boundary_; + if (pattern.size() > buf_.size()) { return true; } + + auto pos = buf_.find(pattern); + if (pos != std::string::npos) { + if (!content_callback(name_, buf_.data(), pos)) { + is_valid_ = false; + is_done_ = false; + return false; + } + + off_ += pos + pattern.size(); + buf_.erase(0, pos + pattern.size()); + state_ = 4; + } else { + if (!content_callback(name_, buf_.data(), pattern.size())) { + is_valid_ = false; + is_done_ = false; + return false; + } + + off_ += pattern.size(); + buf_.erase(0, pattern.size()); + } + } + break; + } + case 4: { // Boundary + if (crlf_.size() > buf_.size()) { return true; } + if (buf_.find(crlf_) == 0) { + buf_.erase(0, crlf_.size()); + off_ += crlf_.size(); + state_ = 1; + } else { + auto pattern = dash_ + crlf_; + if (pattern.size() > buf_.size()) { return true; } + if (buf_.find(pattern) == 0) { + buf_.erase(0, pattern.size()); + off_ += pattern.size(); + is_valid_ = true; + state_ = 5; + } else { + is_done_ = true; + return true; + } + } + break; + } + case 5: { // Done + is_valid_ = false; + return false; + } + } + } - process_multipart_ranges_data( - req, res, boundary, content_type, - [&](const std::string &token) { data += token; }, - [&](const char *token) { data += token; }, - [&](size_t offset, size_t length) { - data += res.body.substr(offset, length); return true; - }); + } - return data; -} + private: + void clear_file_info() { + name_.clear(); + file_.filename.clear(); + file_.content_type.clear(); + } -inline size_t -get_multipart_ranges_data_length(const Request &req, Response &res, - const std::string &boundary, - const std::string &content_type) { - size_t data_length = 0; + const std::string dash_ = "--"; + const std::string crlf_ = "\r\n"; + std::string boundary_; - process_multipart_ranges_data( - req, res, boundary, content_type, - [&](const std::string &token) { data_length += token.size(); }, - [&](const char *token) { data_length += strlen(token); }, - [&](size_t /*offset*/, size_t length) { - data_length += length; + std::string buf_; + size_t state_ = 0; + size_t is_valid_ = false; + size_t is_done_ = false; + size_t off_ = 0; + std::string name_; + MultipartFile file_; + }; + + inline std::string to_lower(const char *beg, const char *end) { + std::string out; + auto it = beg; + while (it != end) { + out += static_cast(::tolower(*it)); + it++; + } + return out; + } + + inline std::string make_multipart_data_boundary() { + static const char data[] = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + + std::random_device seed_gen; + std::mt19937 engine(seed_gen()); + + std::string result = "--cpp-httplib-multipart-data-"; + + for (auto i = 0; i < 16; i++) { + result += data[engine() % (sizeof(data) - 1)]; + } + + return result; + } + + inline std::pair + get_range_offset_and_length(const Request &req, size_t content_length, + size_t index) { + auto r = req.ranges[index]; + + if (r.first == -1 && r.second == -1) { + return std::make_pair(0, content_length); + } + + if (r.first == -1) { + r.first = content_length - r.second; + r.second = content_length - 1; + } + + if (r.second == -1) { r.second = content_length - 1; } + + return std::make_pair(r.first, r.second - r.first + 1); + } + + inline std::string make_content_range_header_field(size_t offset, size_t length, + size_t content_length) { + std::string field = "bytes "; + field += std::to_string(offset); + field += "-"; + field += std::to_string(offset + length - 1); + field += "/"; + field += std::to_string(content_length); + return field; + } + + template + bool process_multipart_ranges_data(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type, + SToken stoken, CToken ctoken, + Content content) { + for (size_t i = 0; i < req.ranges.size(); i++) { + ctoken("--"); + stoken(boundary); + ctoken("\r\n"); + if (!content_type.empty()) { + ctoken("Content-Type: "); + stoken(content_type); + ctoken("\r\n"); + } + + auto offsets = detail::get_range_offset_and_length(req, res.body.size(), i); + auto offset = offsets.first; + auto length = offsets.second; + + ctoken("Content-Range: "); + stoken(make_content_range_header_field(offset, length, res.body.size())); + ctoken("\r\n"); + ctoken("\r\n"); + if (!content(offset, length)) { return false; } + ctoken("\r\n"); + } + + ctoken("--"); + stoken(boundary); + ctoken("--\r\n"); + + return true; + } + + inline std::string make_multipart_ranges_data(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type) { + std::string data; + + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { data += token; }, + [&](const char *token) { data += token; }, + [&](size_t offset, size_t length) { + data += res.body.substr(offset, length); + return true; + }); + + return data; + } + + inline size_t + get_multipart_ranges_data_length(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type) { + size_t data_length = 0; + + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { data_length += token.size(); }, + [&](const char *token) { data_length += strlen(token); }, + [&](size_t /*offset*/, size_t length) { + data_length += length; + return true; + }); + + return data_length; + } + + inline bool write_multipart_ranges_data(Stream &strm, const Request &req, + Response &res, + const std::string &boundary, + const std::string &content_type) { + return process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { strm.write(token); }, + [&](const char *token) { strm.write(token); }, + [&](size_t offset, size_t length) { + return detail::write_content(strm, res.content_provider, offset, + length) >= 0; + }); + } + + inline std::pair + get_range_offset_and_length(const Request &req, const Response &res, + size_t index) { + auto r = req.ranges[index]; + + if (r.second == -1) { r.second = res.content_length - 1; } + + return std::make_pair(r.first, r.second - r.first + 1); + } + + inline bool expect_content(const Request &req) { + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || + req.method == "PRI") { return true; - }); - - return data_length; -} - -inline bool write_multipart_ranges_data(Stream &strm, const Request &req, - Response &res, - const std::string &boundary, - const std::string &content_type) { - return process_multipart_ranges_data( - req, res, boundary, content_type, - [&](const std::string &token) { strm.write(token); }, - [&](const char *token) { strm.write(token); }, - [&](size_t offset, size_t length) { - return detail::write_content(strm, res.content_provider, offset, - length) >= 0; - }); -} - -inline std::pair -get_range_offset_and_length(const Request &req, const Response &res, - size_t index) { - auto r = req.ranges[index]; - - if (r.second == -1) { r.second = res.content_provider_resource_length - 1; } - - return std::make_pair(r.first, r.second - r.first + 1); -} + } + // TODO: check if Content-Length is set + return false; + } #ifdef _WIN32 -class WSInit { + class WSInit { public: WSInit() { WSADATA wsaData; @@ -2010,433 +2301,473 @@ public: static WSInit wsinit_; #endif -} // namespace detail + } // namespace detail // Header utilities -inline std::pair make_range_header(Ranges ranges) { - std::string field = "bytes="; - auto i = 0; - for (auto r : ranges) { - if (i != 0) { field += ", "; } - if (r.first != -1) { field += std::to_string(r.first); } - field += '-'; - if (r.second != -1) { field += std::to_string(r.second); } - i++; + inline std::pair make_range_header(Ranges ranges) { + std::string field = "bytes="; + auto i = 0; + for (auto r : ranges) { + if (i != 0) { field += ", "; } + if (r.first != -1) { field += std::to_string(r.first); } + field += '-'; + if (r.second != -1) { field += std::to_string(r.second); } + i++; + } + return std::make_pair("Range", field); } - return std::make_pair("Range", field); -} -inline std::pair -make_basic_authentication_header(const std::string &username, - const std::string &password) { - auto field = "Basic " + detail::base64_encode(username + ":" + password); - return std::make_pair("Authorization", field); -} + inline std::pair + make_basic_authentication_header(const std::string &username, + const std::string &password) { + auto field = "Basic " + detail::base64_encode(username + ":" + password); + return std::make_pair("Authorization", field); + } // Request implementation -inline bool Request::has_header(const char *key) const { - return detail::has_header(headers, key); -} + inline bool Request::has_header(const char *key) const { + return detail::has_header(headers, key); + } -inline std::string Request::get_header_value(const char *key, size_t id) const { - return detail::get_header_value(headers, key, id, ""); -} + inline std::string Request::get_header_value(const char *key, size_t id) const { + return detail::get_header_value(headers, key, id, ""); + } -inline size_t Request::get_header_value_count(const char *key) const { - auto r = headers.equal_range(key); - return std::distance(r.first, r.second); -} + inline size_t Request::get_header_value_count(const char *key) const { + auto r = headers.equal_range(key); + return std::distance(r.first, r.second); + } -inline void Request::set_header(const char *key, const char *val) { - headers.emplace(key, val); -} + inline void Request::set_header(const char *key, const char *val) { + headers.emplace(key, val); + } -inline void Request::set_header(const char *key, const std::string &val) { - headers.emplace(key, val); -} + inline void Request::set_header(const char *key, const std::string &val) { + headers.emplace(key, val); + } -inline bool Request::has_param(const char *key) const { - return params.find(key) != params.end(); -} + inline bool Request::has_param(const char *key) const { + return params.find(key) != params.end(); + } -inline std::string Request::get_param_value(const char *key, size_t id) const { - auto it = params.find(key); - std::advance(it, id); - if (it != params.end()) { return it->second; } - return std::string(); -} + inline std::string Request::get_param_value(const char *key, size_t id) const { + auto it = params.find(key); + std::advance(it, id); + if (it != params.end()) { return it->second; } + return std::string(); + } -inline size_t Request::get_param_value_count(const char *key) const { - auto r = params.equal_range(key); - return std::distance(r.first, r.second); -} + inline size_t Request::get_param_value_count(const char *key) const { + auto r = params.equal_range(key); + return std::distance(r.first, r.second); + } -inline bool Request::has_file(const char *key) const { - return files.find(key) != files.end(); -} + inline bool Request::is_multipart_form_data() const { + const auto &content_type = get_header_value("Content-Type"); + return !content_type.find("multipart/form-data"); + } -inline MultipartFile Request::get_file_value(const char *key) const { - auto it = files.find(key); - if (it != files.end()) { return it->second; } - return MultipartFile(); -} + inline bool Request::has_file(const char *key) const { + return files.find(key) != files.end(); + } + + inline MultipartFile Request::get_file_value(const char *key) const { + auto it = files.find(key); + if (it != files.end()) { return it->second; } + return MultipartFile(); + } // Response implementation -inline bool Response::has_header(const char *key) const { - return headers.find(key) != headers.end(); -} + inline bool Response::has_header(const char *key) const { + return headers.find(key) != headers.end(); + } -inline std::string Response::get_header_value(const char *key, - size_t id) const { - return detail::get_header_value(headers, key, id, ""); -} + inline std::string Response::get_header_value(const char *key, + size_t id) const { + return detail::get_header_value(headers, key, id, ""); + } -inline size_t Response::get_header_value_count(const char *key) const { - auto r = headers.equal_range(key); - return std::distance(r.first, r.second); -} + inline size_t Response::get_header_value_count(const char *key) const { + auto r = headers.equal_range(key); + return std::distance(r.first, r.second); + } -inline void Response::set_header(const char *key, const char *val) { - headers.emplace(key, val); -} + inline void Response::set_header(const char *key, const char *val) { + headers.emplace(key, val); + } -inline void Response::set_header(const char *key, const std::string &val) { - headers.emplace(key, val); -} + inline void Response::set_header(const char *key, const std::string &val) { + headers.emplace(key, val); + } -inline void Response::set_redirect(const char *url) { - set_header("Location", url); - status = 302; -} + inline void Response::set_redirect(const char *url) { + set_header("Location", url); + status = 302; + } -inline void Response::set_content(const char *s, size_t n, - const char *content_type) { - body.assign(s, n); - set_header("Content-Type", content_type); -} + inline void Response::set_content(const char *s, size_t n, + const char *content_type) { + body.assign(s, n); + set_header("Content-Type", content_type); + } -inline void Response::set_content(const std::string &s, - const char *content_type) { - body = s; - set_header("Content-Type", content_type); -} + inline void Response::set_content(const std::string &s, + const char *content_type) { + body = s; + set_header("Content-Type", content_type); + } -inline void Response::set_content_provider( + inline void Response::set_content_provider( size_t length, std::function provider, std::function resource_releaser) { - assert(length > 0); - content_provider_resource_length = length; - content_provider = [provider](size_t offset, size_t length, DataSink sink, - Done) { provider(offset, length, sink); }; - content_provider_resource_releaser = resource_releaser; -} + assert(length > 0); + content_length = length; + content_provider = [provider](size_t offset, size_t length, DataSink sink, + Done) { provider(offset, length, sink); }; + content_provider_resource_releaser = resource_releaser; + } -inline void Response::set_chunked_content_provider( + inline void Response::set_chunked_content_provider( std::function provider, std::function resource_releaser) { - content_provider_resource_length = 0; - content_provider = [provider](size_t offset, size_t, DataSink sink, - Done done) { provider(offset, sink, done); }; - content_provider_resource_releaser = resource_releaser; -} + content_length = 0; + content_provider = [provider](size_t offset, size_t, DataSink sink, + Done done) { provider(offset, sink, done); }; + content_provider_resource_releaser = resource_releaser; + } // Rstream implementation -template -inline int Stream::write_format(const char *fmt, const Args &... args) { - const auto bufsiz = 2048; - char buf[bufsiz]; + template + inline int Stream::write_format(const char *fmt, const Args &... args) { + std::array buf; #if defined(_MSC_VER) && _MSC_VER < 1900 - auto n = _snprintf_s(buf, bufsiz, bufsiz - 1, fmt, args...); + auto n = _snprintf_s(buf, bufsiz, buf.size() - 1, fmt, args...); #else - auto n = snprintf(buf, bufsiz - 1, fmt, args...); + auto n = snprintf(buf.data(), buf.size() - 1, fmt, args...); #endif - if (n <= 0) { return n; } + if (n <= 0) { return n; } - if (n >= bufsiz - 1) { - std::vector glowable_buf(bufsiz); + if (n >= static_cast(buf.size()) - 1) { + std::vector glowable_buf(buf.size()); - while (n >= static_cast(glowable_buf.size() - 1)) { - glowable_buf.resize(glowable_buf.size() * 2); + while (n >= static_cast(glowable_buf.size() - 1)) { + glowable_buf.resize(glowable_buf.size() * 2); #if defined(_MSC_VER) && _MSC_VER < 1900 - n = _snprintf_s(&glowable_buf[0], glowable_buf.size(), + n = _snprintf_s(&glowable_buf[0], glowable_buf.size(), glowable_buf.size() - 1, fmt, args...); #else - n = snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...); + n = snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...); #endif + } + return write(&glowable_buf[0], n); + } else { + return write(buf.data(), n); } - return write(&glowable_buf[0], n); - } else { - return write(buf, n); } -} // Socket stream implementation -inline SocketStream::SocketStream(socket_t sock) : sock_(sock) {} + inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec) + : sock_(sock), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec) {} -inline SocketStream::~SocketStream() {} + inline SocketStream::~SocketStream() {} -inline int SocketStream::read(char *ptr, size_t size) { - if (detail::select_read(sock_, CPPHTTPLIB_READ_TIMEOUT_SECOND, - CPPHTTPLIB_READ_TIMEOUT_USECOND) > 0) { - return recv(sock_, ptr, static_cast(size), 0); + inline int SocketStream::read(char *ptr, size_t size) { + if (detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0) { + return recv(sock_, ptr, static_cast(size), 0); + } + return -1; } - return -1; -} -inline int SocketStream::write(const char *ptr, size_t size) { - return send(sock_, ptr, static_cast(size), 0); -} + inline int SocketStream::write(const char *ptr, size_t size) { + return send(sock_, ptr, static_cast(size), 0); + } -inline int SocketStream::write(const char *ptr) { - return write(ptr, strlen(ptr)); -} + inline int SocketStream::write(const char *ptr) { + return write(ptr, strlen(ptr)); + } -inline int SocketStream::write(const std::string &s) { - return write(s.data(), s.size()); -} + inline int SocketStream::write(const std::string &s) { + return write(s.data(), s.size()); + } -inline std::string SocketStream::get_remote_addr() const { - return detail::get_remote_addr(sock_); -} + inline std::string SocketStream::get_remote_addr() const { + return detail::get_remote_addr(sock_); + } // Buffer stream implementation -inline int BufferStream::read(char *ptr, size_t size) { + inline int BufferStream::read(char *ptr, size_t size) { #if defined(_MSC_VER) && _MSC_VER < 1900 - return static_cast(buffer._Copy_s(ptr, size, size)); + return static_cast(buffer._Copy_s(ptr, size, size)); #else - return static_cast(buffer.copy(ptr, size)); + return static_cast(buffer.copy(ptr, size)); #endif -} + } -inline int BufferStream::write(const char *ptr, size_t size) { - buffer.append(ptr, size); - return static_cast(size); -} + inline int BufferStream::write(const char *ptr, size_t size) { + buffer.append(ptr, size); + return static_cast(size); + } -inline int BufferStream::write(const char *ptr) { - return write(ptr, strlen(ptr)); -} + inline int BufferStream::write(const char *ptr) { + return write(ptr, strlen(ptr)); + } -inline int BufferStream::write(const std::string &s) { - return write(s.data(), s.size()); -} + inline int BufferStream::write(const std::string &s) { + return write(s.data(), s.size()); + } -inline std::string BufferStream::get_remote_addr() const { return ""; } + inline std::string BufferStream::get_remote_addr() const { return ""; } -inline const std::string &BufferStream::get_buffer() const { return buffer; } + inline const std::string &BufferStream::get_buffer() const { return buffer; } // HTTP server implementation -inline Server::Server() + inline Server::Server() : keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT), + read_timeout_sec_(CPPHTTPLIB_READ_TIMEOUT_SECOND), + read_timeout_usec_(CPPHTTPLIB_READ_TIMEOUT_USECOND), payload_max_length_(CPPHTTPLIB_PAYLOAD_MAX_LENGTH), is_running_(false), svr_sock_(INVALID_SOCKET) { #ifndef _WIN32 - signal(SIGPIPE, SIG_IGN); + signal(SIGPIPE, SIG_IGN); #endif - new_task_queue = [] { + new_task_queue = [] { #if CPPHTTPLIB_THREAD_POOL_COUNT > 0 - return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); + return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); +#elif CPPHTTPLIB_THREAD_POOL_COUNT == 0 + return new Threads(); #else - return new Threads(); + return new NoThread(); #endif - }; -} - -inline Server::~Server() {} - -inline Server &Server::Get(const char *pattern, Handler handler) { - get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; -} - -inline Server &Server::Post(const char *pattern, Handler handler) { - post_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; -} - -inline Server &Server::Put(const char *pattern, Handler handler) { - put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; -} - -inline Server &Server::Patch(const char *pattern, Handler handler) { - patch_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; -} - -inline Server &Server::Delete(const char *pattern, Handler handler) { - delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; -} - -inline Server &Server::Options(const char *pattern, Handler handler) { - options_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); - return *this; -} - -inline bool Server::set_base_dir(const char *path) { - if (detail::is_dir(path)) { - base_dir_ = path; - return true; - } - return false; -} - -inline void Server::set_file_request_handler(Handler handler) { - file_request_handler_ = handler; -} - -inline void Server::set_error_handler(Handler handler) { - error_handler_ = handler; -} - -inline void Server::set_logger(Logger logger) { logger_ = logger; } - -inline void Server::set_keep_alive_max_count(size_t count) { - keep_alive_max_count_ = count; -} - -inline void Server::set_payload_max_length(size_t length) { - payload_max_length_ = length; -} - -inline bool Server::bind_to_port(const char *host, int port, int socket_flags) { - if (bind_internal(host, port, socket_flags) < 0) return false; - return true; -} -inline int Server::bind_to_any_port(const char *host, int socket_flags) { - return bind_internal(host, 0, socket_flags); -} - -inline bool Server::listen_after_bind() { return listen_internal(); } - -inline bool Server::listen(const char *host, int port, int socket_flags) { - return bind_to_port(host, port, socket_flags) && listen_internal(); -} - -inline bool Server::is_running() const { return is_running_; } - -inline void Server::stop() { - if (is_running_) { - assert(svr_sock_ != INVALID_SOCKET); - std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); - detail::shutdown_socket(sock); - detail::close_socket(sock); - } -} - -inline bool Server::parse_request_line(const char *s, Request &req) { - static std::regex re("(GET|HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH|PRI) " - "(([^?]+)(?:\\?(.+?))?) (HTTP/1\\.[01])\r\n"); - - std::cmatch m; - if (std::regex_match(s, m, re)) { - req.version = std::string(m[5]); - req.method = std::string(m[1]); - req.target = std::string(m[2]); - req.path = detail::decode_url(m[3]); - - // Parse query text - auto len = std::distance(m[4].first, m[4].second); - if (len > 0) { detail::parse_query_text(m[4], req.params); } - - return true; + }; } - return false; -} + inline Server::~Server() {} -inline bool Server::write_response(Stream &strm, bool last_connection, - const Request &req, Response &res) { - assert(res.status != -1); + inline Server &Server::Get(const char *pattern, Handler handler) { + get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; + } - if (400 <= res.status && error_handler_) { error_handler_(req, res); } + inline Server &Server::Post(const char *pattern, Handler handler) { + post_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; + } - // Response line - if (!strm.write_format("HTTP/1.1 %d %s\r\n", res.status, - detail::status_message(res.status))) { + inline Server &Server::Post(const char *pattern, + HandlerWithContentReader handler) { + post_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; + } + + inline Server &Server::Put(const char *pattern, Handler handler) { + put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; + } + + inline Server &Server::Put(const char *pattern, + HandlerWithContentReader handler) { + put_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; + } + + inline Server &Server::Patch(const char *pattern, Handler handler) { + patch_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; + } + + inline Server &Server::Patch(const char *pattern, + HandlerWithContentReader handler) { + patch_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; + } + + inline Server &Server::Delete(const char *pattern, Handler handler) { + delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; + } + + inline Server &Server::Options(const char *pattern, Handler handler) { + options_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; + } + + inline bool Server::set_base_dir(const char *dir, const char *mount_point) { + if (detail::is_dir(dir)) { + std::string mnt = mount_point ? mount_point : "/"; + if (!mnt.empty() && mnt[0] == '/') { + base_dirs_.emplace_back(mnt, dir); + return true; + } + } return false; } - // Headers - if (last_connection || req.get_header_value("Connection") == "close") { - res.set_header("Connection", "close"); + inline void Server::set_file_request_handler(Handler handler) { + file_request_handler_ = std::move(handler); } - if (!last_connection && req.get_header_value("Connection") == "Keep-Alive") { - res.set_header("Connection", "Keep-Alive"); + inline void Server::set_error_handler(Handler handler) { + error_handler_ = std::move(handler); } - if (!res.has_header("Content-Type")) { - res.set_header("Content-Type", "text/plain"); + inline void Server::set_logger(Logger logger) { logger_ = std::move(logger); } + + inline void Server::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; } - if (!res.has_header("Accept-Ranges")) { - res.set_header("Accept-Ranges", "bytes"); + inline void Server::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; } - std::string content_type; - std::string boundary; + inline void Server::set_payload_max_length(size_t length) { + payload_max_length_ = length; + } - if (req.ranges.size() > 1) { - boundary = detail::make_multipart_data_boundary(); + inline bool Server::bind_to_port(const char *host, int port, int socket_flags) { + if (bind_internal(host, port, socket_flags) < 0) return false; + return true; + } + inline int Server::bind_to_any_port(const char *host, int socket_flags) { + return bind_internal(host, 0, socket_flags); + } - auto it = res.headers.find("Content-Type"); - if (it != res.headers.end()) { - content_type = it->second; - res.headers.erase(it); + inline bool Server::listen_after_bind() { return listen_internal(); } + + inline bool Server::listen(const char *host, int port, int socket_flags) { + return bind_to_port(host, port, socket_flags) && listen_internal(); + } + + inline bool Server::is_running() const { return is_running_; } + + inline void Server::stop() { + if (is_running_) { + assert(svr_sock_ != INVALID_SOCKET); + std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } + } + + inline bool Server::parse_request_line(const char *s, Request &req) { + static std::regex re( + "(GET|HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH|PRI) " + "(([^?]+)(?:\\?(.*?))?) (HTTP/1\\.[01])\r\n"); + + std::cmatch m; + if (std::regex_match(s, m, re)) { + req.version = std::string(m[5]); + req.method = std::string(m[1]); + req.target = std::string(m[2]); + req.path = detail::decode_url(m[3]); + + // Parse query text + auto len = std::distance(m[4].first, m[4].second); + if (len > 0) { detail::parse_query_text(m[4], req.params); } + + return true; } - res.headers.emplace("Content-Type", - "multipart/byteranges; boundary=" + boundary); + return false; } - if (res.body.empty()) { - if (res.content_provider_resource_length > 0) { - size_t length = 0; + inline bool Server::write_response(Stream &strm, bool last_connection, + const Request &req, Response &res) { + assert(res.status != -1); + + if (400 <= res.status && error_handler_) { error_handler_(req, res); } + + // Response line + if (!strm.write_format("HTTP/1.1 %d %s\r\n", res.status, + detail::status_message(res.status))) { + return false; + } + + // Headers + if (last_connection || req.get_header_value("Connection") == "close") { + res.set_header("Connection", "close"); + } + + if (!last_connection && req.get_header_value("Connection") == "Keep-Alive") { + res.set_header("Connection", "Keep-Alive"); + } + + if (!res.has_header("Content-Type")) { + res.set_header("Content-Type", "text/plain"); + } + + if (!res.has_header("Accept-Ranges")) { + res.set_header("Accept-Ranges", "bytes"); + } + + std::string content_type; + std::string boundary; + + if (req.ranges.size() > 1) { + boundary = detail::make_multipart_data_boundary(); + + auto it = res.headers.find("Content-Type"); + if (it != res.headers.end()) { + content_type = it->second; + res.headers.erase(it); + } + + res.headers.emplace("Content-Type", + "multipart/byteranges; boundary=" + boundary); + } + + if (res.body.empty()) { + if (res.content_length > 0) { + size_t length = 0; + if (req.ranges.empty()) { + length = res.content_length; + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.content_length, 0); + auto offset = offsets.first; + length = offsets.second; + auto content_range = detail::make_content_range_header_field( + offset, length, res.content_length); + res.set_header("Content-Range", content_range); + } else { + length = detail::get_multipart_ranges_data_length(req, res, boundary, + content_type); + } + res.set_header("Content-Length", std::to_string(length)); + } else { + if (res.content_provider) { + res.set_header("Transfer-Encoding", "chunked"); + } else { + res.set_header("Content-Length", "0"); + } + } + } else { if (req.ranges.empty()) { - length = res.content_provider_resource_length; + ; } else if (req.ranges.size() == 1) { - auto offsets = detail::get_range_offset_and_length( - req, res.content_provider_resource_length, 0); - auto offset = offsets.first; - length = offsets.second; - auto content_range = detail::make_content_range_header_field( - offset, length, res.content_provider_resource_length); - res.set_header("Content-Range", content_range); - } else { - length = detail::get_multipart_ranges_data_length(req, res, boundary, - content_type); - } - res.set_header("Content-Length", std::to_string(length)); - } else { - if (res.content_provider) { - res.set_header("Transfer-Encoding", "chunked"); - } else { - res.set_header("Content-Length", "0"); - } - } - } else { - if (req.ranges.empty()) { - ; - } else if (req.ranges.size() == 1) { - auto offsets = + auto offsets = detail::get_range_offset_and_length(req, res.body.size(), 0); - auto offset = offsets.first; - auto length = offsets.second; - auto content_range = detail::make_content_range_header_field( + auto offset = offsets.first; + auto length = offsets.second; + auto content_range = detail::make_content_range_header_field( offset, length, res.body.size()); - res.set_header("Content-Range", content_range); - res.body = res.body.substr(offset, length); - } else { - res.body = + res.set_header("Content-Range", content_range); + res.body = res.body.substr(offset, length); + } else { + res.body = detail::make_multipart_ranges_data(req, res, boundary, content_type); - } + } #ifdef CPPHTTPLIB_ZLIB_SUPPORT - // TODO: 'Accpet-Encoding' has gzip, not gzip;q=0 + // TODO: 'Accpet-Encoding' has gzip, not gzip;q=0 const auto &encodings = req.get_header_value("Accept-Encoding"); if (encodings.find("gzip") != std::string::npos && detail::can_compress(res.get_header_value("Content-Type"))) { @@ -2446,85 +2777,174 @@ inline bool Server::write_response(Stream &strm, bool last_connection, } #endif - auto length = std::to_string(res.body.size()); - res.set_header("Content-Length", length); - } + auto length = std::to_string(res.body.size()); + res.set_header("Content-Length", length); + } - if (!detail::write_headers(strm, res, Headers())) { return false; } + if (!detail::write_headers(strm, res, Headers())) { return false; } - // Body - if (req.method != "HEAD") { - if (!res.body.empty()) { - if (!strm.write(res.body)) { return false; } - } else if (res.content_provider) { - if (!write_content_with_provider(strm, req, res, boundary, - content_type)) { - return false; + // Body + if (req.method != "HEAD") { + if (!res.body.empty()) { + if (!strm.write(res.body)) { return false; } + } else if (res.content_provider) { + if (!write_content_with_provider(strm, req, res, boundary, + content_type)) { + return false; + } } } + + // Log + if (logger_) { logger_(req, res); } + + return true; } - // Log - if (logger_) { logger_(req, res); } - - return true; -} - -inline bool -Server::write_content_with_provider(Stream &strm, const Request &req, - Response &res, const std::string &boundary, - const std::string &content_type) { - if (res.content_provider_resource_length) { - if (req.ranges.empty()) { - if (detail::write_content(strm, res.content_provider, 0, - res.content_provider_resource_length) < 0) { - return false; - } - } else if (req.ranges.size() == 1) { - auto offsets = detail::get_range_offset_and_length( - req, res.content_provider_resource_length, 0); - auto offset = offsets.first; - auto length = offsets.second; - if (detail::write_content(strm, res.content_provider, offset, length) < - 0) { - return false; + inline bool + Server::write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type) { + if (res.content_length) { + if (req.ranges.empty()) { + if (detail::write_content(strm, res.content_provider, 0, + res.content_length) < 0) { + return false; + } + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.content_length, 0); + auto offset = offsets.first; + auto length = offsets.second; + if (detail::write_content(strm, res.content_provider, offset, length) < + 0) { + return false; + } + } else { + if (!detail::write_multipart_ranges_data(strm, req, res, boundary, + content_type)) { + return false; + } } } else { - if (!detail::write_multipart_ranges_data(strm, req, res, boundary, - content_type)) { + if (detail::write_content_chunked(strm, res.content_provider) < 0) { return false; } } - } else { - if (detail::write_content_chunked(strm, res.content_provider) < 0) { - return false; - } - } - return true; -} - -inline bool Server::handle_file_request(Request &req, Response &res) { - if (!base_dir_.empty() && detail::is_valid_path(req.path)) { - std::string path = base_dir_ + req.path; - - if (!path.empty() && path.back() == '/') { path += "index.html"; } - - if (detail::is_file(path)) { - detail::read_file(path, res.body); - auto type = detail::find_content_type(path); - if (type) { res.set_header("Content-Type", type); } - res.status = 200; - if (file_request_handler_) { file_request_handler_(req, res); } - return true; - } + return true; } - return false; -} + inline bool Server::read_content(Stream &strm, bool last_connection, + Request &req, Response &res) { + auto ret = read_content_core(strm, last_connection, req, res, + // Regular + [&](const char *buf, size_t n) { + if (req.body.size() + n > req.body.max_size()) { return false; } + req.body.append(buf, n); + return true; + }, + // Multipart + [&](const std::string &name, const MultipartFile &file) { + req.files.emplace(name, file); + return true; + }, + [&](const std::string &name, const char *buf, size_t n) { + // TODO: handle elements with a same key + auto it = req.files.find(name); + auto &content = it->second.content; + if (content.size() + n > content.max_size()) { return false; } + content.append(buf, n); + return true; + } + ); -inline socket_t Server::create_server_socket(const char *host, int port, - int socket_flags) const { - return detail::create_socket( + const auto &content_type = req.get_header_value("Content-Type"); + if (!content_type.find("application/x-www-form-urlencoded")) { + detail::parse_query_text(req.body, req.params); + } + + return ret; + } + + inline bool + Server::read_content_with_content_receiver(Stream &strm, bool last_connection, + Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + MultipartContentReceiver multipart_receiver) { + return read_content_core(strm, last_connection, req, res, + receiver, multipart_header, multipart_receiver); + } + + inline bool + Server::read_content_core(Stream &strm, bool last_connection, + Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader mulitpart_header, + MultipartContentReceiver multipart_receiver) { + detail::MultipartFormDataParser multipart_form_data_parser; + ContentReceiver out; + + if (req.is_multipart_form_data()) { + const auto &content_type = req.get_header_value("Content-Type"); + std::string boundary; + if (!detail::parse_multipart_boundary(content_type, boundary)) { + res.status = 400; + return write_response(strm, last_connection, req, res); + } + + multipart_form_data_parser.set_boundary(boundary); + out = [&](const char *buf, size_t n) { + return multipart_form_data_parser.parse(buf, n, multipart_receiver, mulitpart_header); + }; + } else { + out = receiver; + } + + if (!detail::read_content(strm, req, payload_max_length_, res.status, + Progress(), out)) { + return write_response(strm, last_connection, req, res); + } + + if (req.is_multipart_form_data()) { + if (!multipart_form_data_parser.is_valid()) { + res.status = 400; + return write_response(strm, last_connection, req, res); + } + } + + return true; + } + + inline bool Server::handle_file_request(Request &req, Response &res) { + for (const auto &kv : base_dirs_) { + const auto &mount_point = kv.first; + const auto &base_dir = kv.second; + + // Prefix match + if (!req.path.find(mount_point)) { + std::string sub_path = "/" + req.path.substr(mount_point.size()); + if (detail::is_valid_path(sub_path)) { + auto path = base_dir + sub_path; + if (path.back() == '/') { path += "index.html"; } + + if (detail::is_file(path)) { + detail::read_file(path, res.body); + auto type = detail::find_content_type(path); + if (type) { res.set_header("Content-Type", type); } + res.status = 200; + if (file_request_handler_) { file_request_handler_(req, res); } + return true; + } + } + } + } + return false; + } + + inline socket_t Server::create_server_socket(const char *host, int port, + int socket_flags) const { + return detail::create_socket( host, port, [](socket_t sock, struct addrinfo &ai) -> bool { if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { @@ -2536,831 +2956,936 @@ inline socket_t Server::create_server_socket(const char *host, int port, return true; }, socket_flags); -} - -inline int Server::bind_internal(const char *host, int port, int socket_flags) { - if (!is_valid()) { return -1; } - - svr_sock_ = create_server_socket(host, port, socket_flags); - if (svr_sock_ == INVALID_SOCKET) { return -1; } - - if (port == 0) { - struct sockaddr_storage address; - socklen_t len = sizeof(address); - if (getsockname(svr_sock_, reinterpret_cast(&address), - &len) == -1) { - return -1; - } - if (address.ss_family == AF_INET) { - return ntohs(reinterpret_cast(&address)->sin_port); - } else if (address.ss_family == AF_INET6) { - return ntohs( - reinterpret_cast(&address)->sin6_port); - } else { - return -1; - } - } else { - return port; } -} -inline bool Server::listen_internal() { - auto ret = true; - is_running_ = true; + inline int Server::bind_internal(const char *host, int port, int socket_flags) { + if (!is_valid()) { return -1; } - { - std::unique_ptr task_queue(new_task_queue()); + svr_sock_ = create_server_socket(host, port, socket_flags); + if (svr_sock_ == INVALID_SOCKET) { return -1; } - for (;;) { - if (svr_sock_ == INVALID_SOCKET) { - // The server socket was closed by 'stop' method. - break; + if (port == 0) { + struct sockaddr_storage address; + socklen_t len = sizeof(address); + if (getsockname(svr_sock_, reinterpret_cast(&address), + &len) == -1) { + return -1; } - - auto val = detail::select_read(svr_sock_, 0, 100000); - - if (val == 0) { // Timeout - continue; + if (address.ss_family == AF_INET) { + return ntohs(reinterpret_cast(&address)->sin_port); + } else if (address.ss_family == AF_INET6) { + return ntohs( + reinterpret_cast(&address)->sin6_port); + } else { + return -1; } + } else { + return port; + } + } - socket_t sock = accept(svr_sock_, nullptr, nullptr); + inline bool Server::listen_internal() { + auto ret = true; + is_running_ = true; - if (sock == INVALID_SOCKET) { - if (errno == EMFILE) { - // The per-process limit of open file descriptors has been reached. - // Try to accept new connections after a short sleep. - std::this_thread::sleep_for(std::chrono::milliseconds(1)); + { + std::unique_ptr task_queue(new_task_queue()); + + for (;;) { + if (svr_sock_ == INVALID_SOCKET) { + // The server socket was closed by 'stop' method. + break; + } + + auto val = detail::select_read(svr_sock_, 0, 100000); + + if (val == 0) { // Timeout continue; } - if (svr_sock_ != INVALID_SOCKET) { - detail::close_socket(svr_sock_); - ret = false; - } else { - ; // The server socket was closed by user. + + socket_t sock = accept(svr_sock_, nullptr, nullptr); + + if (sock == INVALID_SOCKET) { + if (errno == EMFILE) { + // The per-process limit of open file descriptors has been reached. + // Try to accept new connections after a short sleep. + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + if (svr_sock_ != INVALID_SOCKET) { + detail::close_socket(svr_sock_); + ret = false; + } else { + ; // The server socket was closed by user. + } + break; } - break; + + task_queue->enqueue([=]() { process_and_close_socket(sock); }); } - task_queue->enqueue([=]() { process_and_close_socket(sock); }); + task_queue->shutdown(); } - task_queue->shutdown(); + is_running_ = false; + return ret; } - is_running_ = false; - return ret; -} + inline bool Server::routing(Request &req, Response &res, Stream &strm, + bool last_connection) { + // File handler + if (req.method == "GET" && handle_file_request(req, res)) { return true; } -inline bool Server::routing(Request &req, Response &res) { - if (req.method == "GET" && handle_file_request(req, res)) { return true; } + if (detail::expect_content(req)) { + // Content reader handler + { + ContentReader reader( + [&](ContentReceiver receiver) { + return read_content_with_content_receiver(strm, last_connection, req, res, + receiver, nullptr, nullptr); + }, + [&](MultipartContentHeader header, MultipartContentReceiver receiver) { + return read_content_with_content_receiver(strm, last_connection, req, res, + nullptr, header, receiver); + } + ); - if (req.method == "GET" || req.method == "HEAD") { - return dispatch_request(req, res, get_handlers_); - } else if (req.method == "POST") { - return dispatch_request(req, res, post_handlers_); - } else if (req.method == "PUT") { - return dispatch_request(req, res, put_handlers_); - } else if (req.method == "DELETE") { - return dispatch_request(req, res, delete_handlers_); - } else if (req.method == "OPTIONS") { - return dispatch_request(req, res, options_handlers_); - } else if (req.method == "PATCH") { - return dispatch_request(req, res, patch_handlers_); - } + if (req.method == "POST") { + if (dispatch_request_for_content_reader( + req, res, reader, post_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PUT") { + if (dispatch_request_for_content_reader( + req, res, reader, put_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PATCH") { + if (dispatch_request_for_content_reader( + req, res, reader, patch_handlers_for_content_reader_)) { + return true; + } + } + } - res.status = 400; - return false; -} - -inline bool Server::dispatch_request(Request &req, Response &res, - Handlers &handlers) { - for (const auto &x : handlers) { - const auto &pattern = x.first; - const auto &handler = x.second; - - if (std::regex_match(req.path, req.matches, pattern)) { - handler(req, res); - return true; + // Read content into `req.body` + if (!read_content(strm, last_connection, req, res)) { return false; } } - } - return false; -} -inline bool -Server::process_request(Stream &strm, bool last_connection, - bool &connection_close, - std::function setup_request) { - const auto bufsiz = 2048; - char buf[bufsiz]; + // Regular handler + if (req.method == "GET" || req.method == "HEAD") { + return dispatch_request(req, res, get_handlers_); + } else if (req.method == "POST") { + return dispatch_request(req, res, post_handlers_); + } else if (req.method == "PUT") { + return dispatch_request(req, res, put_handlers_); + } else if (req.method == "DELETE") { + return dispatch_request(req, res, delete_handlers_); + } else if (req.method == "OPTIONS") { + return dispatch_request(req, res, options_handlers_); + } else if (req.method == "PATCH") { + return dispatch_request(req, res, patch_handlers_); + } - detail::stream_line_reader reader(strm, buf, bufsiz); - - // Connection has been closed on client - if (!reader.getline()) { return false; } - - Request req; - Response res; - - res.version = "HTTP/1.1"; - - // Check if the request URI doesn't exceed the limit - if (reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { - Headers dummy; - detail::read_headers(strm, dummy); - res.status = 414; - return write_response(strm, last_connection, req, res); - } - - // Request line and headers - if (!parse_request_line(reader.ptr(), req) || - !detail::read_headers(strm, req.headers)) { res.status = 400; - return write_response(strm, last_connection, req, res); + return false; } - if (req.get_header_value("Connection") == "close") { - connection_close = true; + inline bool Server::dispatch_request(Request &req, Response &res, + Handlers &handlers) { + for (const auto &x : handlers) { + const auto &pattern = x.first; + const auto &handler = x.second; + + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res); + return true; + } + } + return false; } - if (req.version == "HTTP/1.0" && - req.get_header_value("Connection") != "Keep-Alive") { - connection_close = true; + inline bool + Server::dispatch_request_for_content_reader(Request &req, Response &res, + ContentReader content_reader, + HandersForContentReader &handlers) { + for (const auto &x : handlers) { + const auto &pattern = x.first; + const auto &handler = x.second; + + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res, content_reader); + return true; + } + } + return false; } - req.remoteAddr = strm.get_remote_addr(); - req.set_header("REMOTE_ADDR", strm.get_remote_addr()); + inline bool + Server::process_request(Stream &strm, bool last_connection, + bool &connection_close, + const std::function &setup_request) { + std::array buf{}; - // Body - if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || req.method == "PRI") { - if (!detail::read_content(strm, req, payload_max_length_, res.status, - Progress(), [&](const char *buf, size_t n) { - if (req.body.size() + n > req.body.max_size()) { - return false; - } - req.body.append(buf, n); - return true; - })) { + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + // Connection has been closed on client + if (!line_reader.getline()) { return false; } + + Request req; + Response res; + + res.version = "HTTP/1.1"; + + // Check if the request URI doesn't exceed the limit + if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = 414; return write_response(strm, last_connection, req, res); } - const auto &content_type = req.get_header_value("Content-Type"); + // Request line and headers + if (!parse_request_line(line_reader.ptr(), req) || + !detail::read_headers(strm, req.headers)) { + res.status = 400; + return write_response(strm, last_connection, req, res); + } - if (!content_type.find("application/x-www-form-urlencoded")) { - detail::parse_query_text(req.body, req.params); - } else if (!content_type.find("multipart/form-data")) { - std::string boundary; - if (!detail::parse_multipart_boundary(content_type, boundary) || - !detail::parse_multipart_formdata(boundary, req.body, req.files)) { - res.status = 400; - return write_response(strm, last_connection, req, res); + if (req.get_header_value("Connection") == "close") { + connection_close = true; + } + + if (req.version == "HTTP/1.0" && + req.get_header_value("Connection") != "Keep-Alive") { + connection_close = true; + } + + req.set_header("REMOTE_ADDR", strm.get_remote_addr()); + + if (req.has_header("Range")) { + const auto &range_header_value = req.get_header_value("Range"); + if (!detail::parse_range_header(range_header_value, req.ranges)) { + // TODO: error } } - } - if (req.has_header("Range")) { - const auto &range_header_value = req.get_header_value("Range"); - if (!detail::parse_range_header(range_header_value, req.ranges)) { - // TODO: error + if (setup_request) { setup_request(req); } + + // Rounting + if (routing(req, res, strm, last_connection)) { + if (res.status == -1) { res.status = req.ranges.empty() ? 200 : 206; } + } else { + if (res.status == -1) { res.status = 404; } } + + return write_response(strm, last_connection, req, res); } - if (setup_request) { setup_request(req); } + inline bool Server::is_valid() const { return true; } - if (routing(req, res)) { - if (res.status == -1) { res.status = req.ranges.empty() ? 200 : 206; } - } else { - if (res.status == -1) { res.status = 404; } - } - - return write_response(strm, last_connection, req, res); -} - -inline bool Server::is_valid() const { return true; } - -inline bool Server::process_and_close_socket(socket_t sock) { - return detail::process_and_close_socket( - false, sock, keep_alive_max_count_, + inline bool Server::process_and_close_socket(socket_t sock) { + return detail::process_and_close_socket( + false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_, [this](Stream &strm, bool last_connection, bool &connection_close) { return process_request(strm, last_connection, connection_close, nullptr); }); -} + } // HTTP client implementation -inline Client::Client(const char *host, int port, time_t timeout_sec) + inline Client::Client(const char *host, int port, time_t timeout_sec) : host_(host), port_(port), timeout_sec_(timeout_sec), host_and_port_(host_ + ":" + std::to_string(port_)), keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT), + read_timeout_sec_(CPPHTTPLIB_READ_TIMEOUT_SECOND), + read_timeout_usec_(CPPHTTPLIB_READ_TIMEOUT_USECOND), follow_location_(false) {} -inline Client::~Client() {} + inline Client::~Client() {} -inline bool Client::is_valid() const { return true; } + inline bool Client::is_valid() const { return true; } -inline socket_t Client::create_client_socket() const { - return detail::create_socket( - host_.c_str(), port_, [=](socket_t sock, struct addrinfo &ai) -> bool { - detail::set_nonblocking(sock, true); - - auto ret = connect(sock, ai.ai_addr, static_cast(ai.ai_addrlen)); - if (ret < 0) { - if (detail::is_connection_error() || - !detail::wait_until_socket_is_ready(sock, timeout_sec_, 0)) { - detail::close_socket(sock); - return false; - } - } - - detail::set_nonblocking(sock, false); - return true; - }); -} - -inline bool Client::read_response_line(Stream &strm, Response &res) { - const auto bufsiz = 2048; - char buf[bufsiz]; - - detail::stream_line_reader reader(strm, buf, bufsiz); - - if (!reader.getline()) { return false; } - - const static std::regex re("(HTTP/1\\.[01]) (\\d+?) .*\r\n"); - - std::cmatch m; - if (std::regex_match(reader.ptr(), m, re)) { - res.version = std::string(m[1]); - res.status = std::stoi(std::string(m[2])); + inline socket_t Client::create_client_socket() const { + return detail::create_client_socket(host_.c_str(), port_, timeout_sec_); } - return true; -} + inline bool Client::read_response_line(Stream &strm, Response &res) { + std::array buf; -inline bool Client::send(const Request &req, Response &res) { - if (req.path.empty()) { return false; } + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); - auto sock = create_client_socket(); - if (sock == INVALID_SOCKET) { return false; } + if (!line_reader.getline()) { return false; } - auto ret = process_and_close_socket( + const static std::regex re("(HTTP/1\\.[01]) (\\d+?) .*\r\n"); + + std::cmatch m; + if (std::regex_match(line_reader.ptr(), m, re)) { + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + } + + return true; + } + + inline bool Client::send(const Request &req, Response &res) { + if (req.path.empty()) { return false; } + + auto sock = create_client_socket(); + if (sock == INVALID_SOCKET) { return false; } + + auto ret = process_and_close_socket( sock, 1, [&](Stream &strm, bool last_connection, bool &connection_close) { return process_request(strm, req, res, last_connection, connection_close); }); - if (ret && follow_location_ && (300 < res.status && res.status < 400)) { - ret = redirect(req, res); - } - - return ret; -} - -inline bool Client::send(const std::vector &requests, - std::vector &responses) { - size_t i = 0; - while (i < requests.size()) { - auto sock = create_client_socket(); - if (sock == INVALID_SOCKET) { return false; } - - if (!process_and_close_socket( - sock, requests.size() - i, - [&](Stream &strm, bool last_connection, bool &connection_close) -> bool { - auto &req = requests[i]; - auto res = Response(); - i++; - - if (req.path.empty()) { return false; } - auto ret = process_request(strm, req, res, last_connection, - connection_close); - - if (ret && follow_location_ && - (300 < res.status && res.status < 400)) { - ret = redirect(req, res); - } - - if (ret) { responses.emplace_back(std::move(res)); } - - return ret; - })) { - return false; + if (ret && follow_location_ && (300 < res.status && res.status < 400)) { + ret = redirect(req, res); } + + return ret; } - return true; -} + inline bool Client::send(const std::vector &requests, + std::vector &responses) { + size_t i = 0; + while (i < requests.size()) { + auto sock = create_client_socket(); + if (sock == INVALID_SOCKET) { return false; } -inline bool Client::redirect(const Request &req, Response &res) { - if (req.redirect_count == 0) { return false; } + if (!process_and_close_socket( + sock, requests.size() - i, + [&](Stream &strm, bool last_connection, + bool &connection_close) -> bool { + auto &req = requests[i]; + auto res = Response(); + i++; - auto location = res.get_header_value("location"); - if (location.empty()) { return false; } + if (req.path.empty()) { return false; } + auto ret = process_request(strm, req, res, last_connection, + connection_close); - std::regex re( + if (ret && follow_location_ && + (300 < res.status && res.status < 400)) { + ret = redirect(req, res); + } + + if (ret) { responses.emplace_back(std::move(res)); } + + return ret; + })) { + return false; + } + } + + return true; + } + + inline bool Client::redirect(const Request &req, Response &res) { + if (req.redirect_count == 0) { return false; } + + auto location = res.get_header_value("location"); + if (location.empty()) { return false; } + + std::regex re( R"(^(?:([^:/?#]+):)?(?://([^/?#]*))?([^?#]*(?:\?[^#]*)?)(?:#.*)?)"); - auto scheme = is_ssl() ? "https" : "http"; + auto scheme = is_ssl() ? "https" : "http"; - std::smatch m; - if (regex_match(location, m, re)) { - auto next_scheme = m[1].str(); - auto next_host = m[2].str(); - auto next_path = m[3].str(); - if (next_host.empty()) { next_host = host_; } - if (next_path.empty()) { next_path = "/"; } + std::smatch m; + if (regex_match(location, m, re)) { + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + auto next_path = m[3].str(); + if (next_host.empty()) { next_host = host_; } + if (next_path.empty()) { next_path = "/"; } - if (next_scheme == scheme && next_host == host_) { - return detail::redirect(*this, req, res, next_path); - } else { - if (next_scheme == "https") { + if (next_scheme == scheme && next_host == host_) { + return detail::redirect(*this, req, res, next_path); + } else { + if (next_scheme == "https") { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - SSLClient cli(next_host.c_str()); - cli.follow_location(true); - return detail::redirect(cli, req, res, next_path); + SSLClient cli(next_host.c_str()); + cli.follow_location(true); + return detail::redirect(cli, req, res, next_path); #else - return false; + return false; #endif - } else { - Client cli(next_host.c_str()); - cli.follow_location(true); - return detail::redirect(cli, req, res, next_path); + } else { + Client cli(next_host.c_str()); + cli.follow_location(true); + return detail::redirect(cli, req, res, next_path); + } } } - } - return false; -} - -inline void Client::write_request(Stream &strm, const Request &req, - bool last_connection) { - BufferStream bstrm; - - // Request line - auto path = detail::encode_url(req.path); - - bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); - - // Additonal headers - Headers headers; - if (last_connection) { headers.emplace("Connection", "close"); } - - if (!req.has_header("Host")) { - if (is_ssl()) { - if (port_ == 443) { - headers.emplace("Host", host_); - } else { - headers.emplace("Host", host_and_port_); - } - } else { - if (port_ == 80) { - headers.emplace("Host", host_); - } else { - headers.emplace("Host", host_and_port_); - } - } - } - - if (!req.has_header("Accept")) { headers.emplace("Accept", "*/*"); } - - if (!req.has_header("User-Agent")) { - headers.emplace("User-Agent", "cpp-httplib/0.2"); - } - - if (req.body.empty()) { - if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH") { - headers.emplace("Content-Length", "0"); - } - } else { - if (!req.has_header("Content-Type")) { - headers.emplace("Content-Type", "text/plain"); - } - - if (!req.has_header("Content-Length")) { - auto length = std::to_string(req.body.size()); - headers.emplace("Content-Length", length); - } - } - - detail::write_headers(bstrm, req, headers); - - // Body - if (!req.body.empty()) { bstrm.write(req.body); } - - // Flush buffer - auto &data = bstrm.get_buffer(); - strm.write(data.data(), data.size()); -} - -inline bool Client::process_request(Stream &strm, const Request &req, - Response &res, bool last_connection, - bool &connection_close) { - // Send request - write_request(strm, req, last_connection); - - // Receive response and headers - if (!read_response_line(strm, res) || - !detail::read_headers(strm, res.headers)) { return false; } - if (res.get_header_value("Connection") == "close" || - res.version == "HTTP/1.0") { - connection_close = true; - } + inline void Client::write_request(Stream &strm, const Request &req, + bool last_connection) { + BufferStream bstrm; - if (req.response_handler) { - if (!req.response_handler(res)) { return false; } - } + // Request line + auto path = detail::encode_url(req.path); - // Body - if (req.method != "HEAD") { - detail::ContentReceiverCore out = [&](const char *buf, size_t n) { - if (res.body.size() + n > res.body.max_size()) { return false; } - res.body.append(buf, n); - return true; - }; + bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); - if (req.content_receiver) { - auto offset = std::make_shared(); - auto length = get_header_value_uint64(res.headers, "Content-Length", 0); - auto receiver = req.content_receiver; - out = [offset, length, receiver](const char *buf, size_t n) { - auto ret = receiver(buf, n, *offset, length); - (*offset) += n; - return ret; - }; + // Additonal headers + Headers headers; + if (last_connection) { headers.emplace("Connection", "close"); } + + if (!req.has_header("Host")) { + if (is_ssl()) { + if (port_ == 443) { + headers.emplace("Host", host_); + } else { + headers.emplace("Host", host_and_port_); + } + } else { + if (port_ == 80) { + headers.emplace("Host", host_); + } else { + headers.emplace("Host", host_and_port_); + } + } } - int dummy_status; - if (!detail::read_content(strm, res, std::numeric_limits::max(), - dummy_status, req.progress, out)) { + if (!req.has_header("Accept")) { headers.emplace("Accept", "*/*"); } + + if (!req.has_header("User-Agent")) { + headers.emplace("User-Agent", "cpp-httplib/0.2"); + } + + if (req.body.empty()) { + if (req.content_provider) { + auto length = std::to_string(req.content_length); + headers.emplace("Content-Length", length); + } else { + headers.emplace("Content-Length", "0"); + } + } else { + if (!req.has_header("Content-Type")) { + headers.emplace("Content-Type", "text/plain"); + } + + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.body.size()); + headers.emplace("Content-Length", length); + } + } + + detail::write_headers(bstrm, req, headers); + + // Flush buffer + auto &data = bstrm.get_buffer(); + strm.write(data.data(), data.size()); + + // Body + if (req.body.empty()) { + if (req.content_provider) { + size_t offset = 0; + size_t end_offset = req.content_length; + while (offset < end_offset) { + req.content_provider(offset, end_offset - offset, + [&](const char *d, size_t l) { + auto written_length = strm.write(d, l); + offset += written_length; + }); + } + } + } else { + strm.write(req.body); + } + } + + inline std::shared_ptr Client::send_with_content_provider( + const char *method, const char *path, const Headers &headers, + const std::string &body, size_t content_length, + ContentProvider content_provider, const char *content_type, bool compress) { +#ifndef CPPHTTPLIB_ZLIB_SUPPORT + (void)compress; +#endif + + Request req; + req.method = method; + req.headers = headers; + req.path = path; + + req.headers.emplace("Content-Type", content_type); + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress) { + if (content_provider) { + size_t offset = 0; + while (offset < content_length) { + content_provider(offset, content_length - offset, + [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + offset += data_len; + }); + } + } else { + req.body = body; + } + + if (!detail::compress(req.body)) { return nullptr; } + req.headers.emplace("Content-Encoding", "gzip"); + } else +#endif + { + if (content_provider) { + req.content_length = content_length; + req.content_provider = content_provider; + } else { + req.body = body; + } + } + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; + } + + inline bool Client::process_request(Stream &strm, const Request &req, + Response &res, bool last_connection, + bool &connection_close) { + // Send request + write_request(strm, req, last_connection); + + // Receive response and headers + if (!read_response_line(strm, res) || + !detail::read_headers(strm, res.headers)) { return false; } + + if (res.get_header_value("Connection") == "close" || + res.version == "HTTP/1.0") { + connection_close = true; + } + + if (req.response_handler) { + if (!req.response_handler(res)) { return false; } + } + + // Body + if (req.method != "HEAD") { + ContentReceiver out = [&](const char *buf, size_t n) { + if (res.body.size() + n > res.body.max_size()) { return false; } + res.body.append(buf, n); + return true; + }; + + if (req.content_receiver) { + out = [&](const char *buf, size_t n) { + return req.content_receiver(buf, n); + }; + } + + int dummy_status; + if (!detail::read_content(strm, res, std::numeric_limits::max(), + dummy_status, req.progress, out)) { + return false; + } + } + + return true; } - return true; -} - -inline bool Client::process_and_close_socket( + inline bool Client::process_and_close_socket( socket_t sock, size_t request_count, std::function - callback) { - request_count = std::min(request_count, keep_alive_max_count_); - return detail::process_and_close_socket(true, sock, request_count, callback); -} - -inline bool Client::is_ssl() const { return false; } - -inline std::shared_ptr Client::Get(const char *path) { - Progress dummy; - return Get(path, Headers(), dummy); -} - -inline std::shared_ptr Client::Get(const char *path, - Progress progress) { - return Get(path, Headers(), progress); -} - -inline std::shared_ptr Client::Get(const char *path, - const Headers &headers) { - Progress dummy; - return Get(path, headers, dummy); -} - -inline std::shared_ptr -Client::Get(const char *path, const Headers &headers, Progress progress) { - Request req; - req.method = "GET"; - req.path = path; - req.headers = headers; - req.progress = progress; - - auto res = std::make_shared(); - return send(req, *res) ? res : nullptr; -} - -inline std::shared_ptr Client::Get(const char *path, - ContentReceiver content_receiver) { - Progress dummy; - return Get(path, Headers(), nullptr, content_receiver, dummy); -} - -inline std::shared_ptr Client::Get(const char *path, - ContentReceiver content_receiver, - Progress progress) { - return Get(path, Headers(), nullptr, content_receiver, progress); -} - -inline std::shared_ptr Client::Get(const char *path, - const Headers &headers, - ContentReceiver content_receiver) { - Progress dummy; - return Get(path, headers, nullptr, content_receiver, dummy); -} - -inline std::shared_ptr Client::Get(const char *path, - const Headers &headers, - ContentReceiver content_receiver, - Progress progress) { - return Get(path, headers, nullptr, content_receiver, progress); -} - -inline std::shared_ptr Client::Get(const char *path, - const Headers &headers, - ResponseHandler response_handler, - ContentReceiver content_receiver) { - Progress dummy; - return Get(path, headers, response_handler, content_receiver, dummy); -} - -inline std::shared_ptr Client::Get(const char *path, - const Headers &headers, - ResponseHandler response_handler, - ContentReceiver content_receiver, - Progress progress) { - Request req; - req.method = "GET"; - req.path = path; - req.headers = headers; - req.response_handler = response_handler; - req.content_receiver = content_receiver; - req.progress = progress; - - auto res = std::make_shared(); - return send(req, *res) ? res : nullptr; -} - -inline std::shared_ptr Client::Head(const char *path) { - return Head(path, Headers()); -} - -inline std::shared_ptr Client::Head(const char *path, - const Headers &headers) { - Request req; - req.method = "HEAD"; - req.headers = headers; - req.path = path; - - auto res = std::make_shared(); - - return send(req, *res) ? res : nullptr; -} - -inline std::shared_ptr Client::Post(const char *path, - const std::string &body, - const char *content_type) { - return Post(path, Headers(), body, content_type); -} - -inline std::shared_ptr Client::Post(const char *path, - const Headers &headers, - const std::string &body, - const char *content_type) { - Request req; - req.method = "POST"; - req.headers = headers; - req.path = path; - - req.headers.emplace("Content-Type", content_type); - req.body = body; - - auto res = std::make_shared(); - - return send(req, *res) ? res : nullptr; -} - -inline std::shared_ptr Client::Post(const char *path, - const Params ¶ms) { - return Post(path, Headers(), params); -} - -inline std::shared_ptr -Client::Post(const char *path, const Headers &headers, const Params ¶ms) { - std::string query; - for (auto it = params.begin(); it != params.end(); ++it) { - if (it != params.begin()) { query += "&"; } - query += it->first; - query += "="; - query += detail::encode_url(it->second); + callback) { + request_count = std::min(request_count, keep_alive_max_count_); + return detail::process_and_close_socket(true, sock, request_count, + read_timeout_sec_, read_timeout_usec_, + callback); } - return Post(path, headers, query, "application/x-www-form-urlencoded"); -} + inline bool Client::is_ssl() const { return false; } -inline std::shared_ptr -Client::Post(const char *path, const MultipartFormDataItems &items) { - return Post(path, Headers(), items); -} - -inline std::shared_ptr -Client::Post(const char *path, const Headers &headers, - const MultipartFormDataItems &items) { - Request req; - req.method = "POST"; - req.headers = headers; - req.path = path; - - auto boundary = detail::make_multipart_data_boundary(); - - req.headers.emplace("Content-Type", - "multipart/form-data; boundary=" + boundary); - - for (const auto &item : items) { - req.body += "--" + boundary + "\r\n"; - req.body += "Content-Disposition: form-data; name=\"" + item.name + "\""; - if (!item.filename.empty()) { - req.body += "; filename=\"" + item.filename + "\""; - } - req.body += "\r\n"; - if (!item.content_type.empty()) { - req.body += "Content-Type: " + item.content_type + "\r\n"; - } - req.body += "\r\n"; - req.body += item.content + "\r\n"; + inline std::shared_ptr Client::Get(const char *path) { + Progress dummy; + return Get(path, Headers(), dummy); } - req.body += "--" + boundary + "--\r\n"; + inline std::shared_ptr Client::Get(const char *path, + Progress progress) { + return Get(path, Headers(), std::move(progress)); + } - auto res = std::make_shared(); + inline std::shared_ptr Client::Get(const char *path, + const Headers &headers) { + Progress dummy; + return Get(path, headers, dummy); + } - return send(req, *res) ? res : nullptr; -} + inline std::shared_ptr + Client::Get(const char *path, const Headers &headers, Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.progress = std::move(progress); -inline std::shared_ptr Client::Put(const char *path, - const std::string &body, - const char *content_type) { - return Put(path, Headers(), body, content_type); -} + auto res = std::make_shared(); + return send(req, *res) ? res : nullptr; + } -inline std::shared_ptr Client::Put(const char *path, - const Headers &headers, - const std::string &body, - const char *content_type) { - Request req; - req.method = "PUT"; - req.headers = headers; - req.path = path; + inline std::shared_ptr Client::Get(const char *path, + ContentReceiver content_receiver) { + Progress dummy; + return Get(path, Headers(), nullptr, std::move(content_receiver), dummy); + } - req.headers.emplace("Content-Type", content_type); - req.body = body; + inline std::shared_ptr Client::Get(const char *path, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, Headers(), nullptr, std::move(content_receiver), progress); + } - auto res = std::make_shared(); - - return send(req, *res) ? res : nullptr; -} - -inline std::shared_ptr Client::Patch(const char *path, - const std::string &body, - const char *content_type) { - return Patch(path, Headers(), body, content_type); -} - -inline std::shared_ptr Client::Patch(const char *path, + inline std::shared_ptr Client::Get(const char *path, const Headers &headers, - const std::string &body, - const char *content_type) { - Request req; - req.method = "PATCH"; - req.headers = headers; - req.path = path; + ContentReceiver content_receiver) { + Progress dummy; + return Get(path, headers, nullptr, std::move(content_receiver), dummy); + } - req.headers.emplace("Content-Type", content_type); - req.body = body; + inline std::shared_ptr Client::Get(const char *path, + const Headers &headers, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, headers, nullptr, std::move(content_receiver), progress); + } - auto res = std::make_shared(); + inline std::shared_ptr Client::Get(const char *path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + Progress dummy; + return Get(path, headers, std::move(response_handler), content_receiver, + dummy); + } - return send(req, *res) ? res : nullptr; -} + inline std::shared_ptr Client::Get(const char *path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.response_handler = std::move(response_handler); + req.content_receiver = std::move(content_receiver); + req.progress = std::move(progress); -inline std::shared_ptr Client::Delete(const char *path) { - return Delete(path, Headers(), std::string(), nullptr); -} + auto res = std::make_shared(); + return send(req, *res) ? res : nullptr; + } -inline std::shared_ptr Client::Delete(const char *path, - const std::string &body, - const char *content_type) { - return Delete(path, Headers(), body, content_type); -} + inline std::shared_ptr Client::Head(const char *path) { + return Head(path, Headers()); + } -inline std::shared_ptr Client::Delete(const char *path, + inline std::shared_ptr Client::Head(const char *path, const Headers &headers) { - return Delete(path, headers, std::string(), nullptr); -} + Request req; + req.method = "HEAD"; + req.headers = headers; + req.path = path; -inline std::shared_ptr Client::Delete(const char *path, - const Headers &headers, + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; + } + + inline std::shared_ptr Client::Post(const char *path, const std::string &body, - const char *content_type) { - Request req; - req.method = "DELETE"; - req.headers = headers; - req.path = path; + const char *content_type, + bool compress) { + return Post(path, Headers(), body, content_type, compress); + } - if (content_type) { req.headers.emplace("Content-Type", content_type); } - req.body = body; + inline std::shared_ptr + Client::Post(const char *path, const Headers &headers, const std::string &body, + const char *content_type, bool compress) { + return send_with_content_provider("POST", path, headers, body, 0, nullptr, + content_type, compress); + } - auto res = std::make_shared(); + inline std::shared_ptr + Client::Post(const char *path, const Params ¶ms, bool compress) { + return Post(path, Headers(), params, compress); + } - return send(req, *res) ? res : nullptr; -} + inline std::shared_ptr Client::Post(const char *path, + size_t content_length, + ContentProvider content_provider, + const char *content_type, + bool compress) { + return Post(path, Headers(), content_length, content_provider, content_type, + compress); + } -inline std::shared_ptr Client::Options(const char *path) { - return Options(path, Headers()); -} + inline std::shared_ptr + Client::Post(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type, + bool compress) { + return send_with_content_provider("POST", path, headers, std::string(), + content_length, content_provider, + content_type, compress); + } -inline std::shared_ptr Client::Options(const char *path, - const Headers &headers) { - Request req; - req.method = "OPTIONS"; - req.path = path; - req.headers = headers; + inline std::shared_ptr Client::Post(const char *path, + const Headers &headers, + const Params ¶ms, + bool compress) { + std::string query; + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { query += "&"; } + query += it->first; + query += "="; + query += detail::encode_url(it->second); + } - auto res = std::make_shared(); + return Post(path, headers, query, "application/x-www-form-urlencoded", + compress); + } - return send(req, *res) ? res : nullptr; -} + inline std::shared_ptr + Client::Post(const char *path, const MultipartFormDataItems &items, + bool compress) { + return Post(path, Headers(), items, compress); + } -inline void Client::set_keep_alive_max_count(size_t count) { - keep_alive_max_count_ = count; -} + inline std::shared_ptr + Client::Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items, bool compress) { + auto boundary = detail::make_multipart_data_boundary(); -inline void Client::follow_location(bool on) { follow_location_ = on; } + std::string body; + + for (const auto &item : items) { + body += "--" + boundary + "\r\n"; + body += "Content-Disposition: form-data; name=\"" + item.name + "\""; + if (!item.filename.empty()) { + body += "; filename=\"" + item.filename + "\""; + } + body += "\r\n"; + if (!item.content_type.empty()) { + body += "Content-Type: " + item.content_type + "\r\n"; + } + body += "\r\n"; + body += item.content + "\r\n"; + } + + body += "--" + boundary + "--\r\n"; + + std::string content_type = "multipart/form-data; boundary=" + boundary; + return Post(path, headers, body, content_type.c_str(), compress); + } + + inline std::shared_ptr Client::Put(const char *path, + const std::string &body, + const char *content_type, + bool compress) { + return Put(path, Headers(), body, content_type, compress); + } + + inline std::shared_ptr + Client::Put(const char *path, const Headers &headers, const std::string &body, + const char *content_type, bool compress) { + return send_with_content_provider("PUT", path, headers, body, 0, nullptr, + content_type, compress); + } + + inline std::shared_ptr Client::Put(const char *path, + size_t content_length, + ContentProvider content_provider, + const char *content_type, + bool compress) { + return Put(path, Headers(), content_length, content_provider, content_type, + compress); + } + + inline std::shared_ptr + Client::Put(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type, + bool compress) { + return send_with_content_provider("PUT", path, headers, std::string(), + content_length, content_provider, + content_type, compress); + } + + inline std::shared_ptr Client::Patch(const char *path, + const std::string &body, + const char *content_type, + bool compress) { + return Patch(path, Headers(), body, content_type, compress); + } + + inline std::shared_ptr + Client::Patch(const char *path, const Headers &headers, const std::string &body, + const char *content_type, bool compress) { + return send_with_content_provider("PATCH", path, headers, body, 0, nullptr, + content_type, compress); + } + + inline std::shared_ptr Client::Patch(const char *path, + size_t content_length, + ContentProvider content_provider, + const char *content_type, + bool compress) { + return Patch(path, Headers(), content_length, content_provider, content_type, + compress); + } + + inline std::shared_ptr + Client::Patch(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type, + bool compress) { + return send_with_content_provider("PATCH", path, headers, std::string(), + content_length, content_provider, + content_type, compress); + } + + inline std::shared_ptr Client::Delete(const char *path) { + return Delete(path, Headers(), std::string(), nullptr); + } + + inline std::shared_ptr Client::Delete(const char *path, + const std::string &body, + const char *content_type) { + return Delete(path, Headers(), body, content_type); + } + + inline std::shared_ptr Client::Delete(const char *path, + const Headers &headers) { + return Delete(path, headers, std::string(), nullptr); + } + + inline std::shared_ptr Client::Delete(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + Request req; + req.method = "DELETE"; + req.headers = headers; + req.path = path; + + if (content_type) { req.headers.emplace("Content-Type", content_type); } + req.body = body; + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; + } + + inline std::shared_ptr Client::Options(const char *path) { + return Options(path, Headers()); + } + + inline std::shared_ptr Client::Options(const char *path, + const Headers &headers) { + Request req; + req.method = "OPTIONS"; + req.path = path; + req.headers = headers; + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; + } + + inline void Client::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; + } + + inline void Client::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; + } + + inline void Client::follow_location(bool on) { follow_location_ = on; } /* * SSL Implementation */ #ifdef CPPHTTPLIB_OPENSSL_SUPPORT -namespace detail { + namespace detail { -template -inline bool process_and_close_socket_ssl(bool is_client_request, socket_t sock, - size_t keep_alive_max_count, - SSL_CTX *ctx, std::mutex &ctx_mutex, - U SSL_connect_or_accept, V setup, - T callback) { - assert(keep_alive_max_count > 0); + template + inline bool process_and_close_socket_ssl( + bool is_client_request, socket_t sock, size_t keep_alive_max_count, + time_t read_timeout_sec, time_t read_timeout_usec, SSL_CTX *ctx, + std::mutex &ctx_mutex, U SSL_connect_or_accept, V setup, T callback) { + assert(keep_alive_max_count > 0); - SSL *ssl = nullptr; - { - std::lock_guard guard(ctx_mutex); - ssl = SSL_new(ctx); - } - - if (!ssl) { - close_socket(sock); - return false; - } - - auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); - SSL_set_bio(ssl, bio, bio); - - if (!setup(ssl)) { - SSL_shutdown(ssl); - { - std::lock_guard guard(ctx_mutex); - SSL_free(ssl); - } - - close_socket(sock); - return false; - } - - bool ret = false; - - if (SSL_connect_or_accept(ssl) == 1) { - if (keep_alive_max_count > 1) { - auto count = keep_alive_max_count; - while (count > 0 && - (is_client_request || - detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, - CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { - SSLSocketStream strm(sock, ssl); - auto last_connection = count == 1; - auto connection_close = false; - - ret = callback(ssl, strm, last_connection, connection_close); - if (!ret || connection_close) { break; } - - count--; + SSL *ssl = nullptr; + { + std::lock_guard guard(ctx_mutex); + ssl = SSL_new(ctx); } - } else { - SSLSocketStream strm(sock, ssl); - auto dummy_connection_close = false; - ret = callback(ssl, strm, true, dummy_connection_close); + + if (!ssl) { + close_socket(sock); + return false; + } + + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); + SSL_set_bio(ssl, bio, bio); + + if (!setup(ssl)) { + SSL_shutdown(ssl); + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + + close_socket(sock); + return false; + } + + bool ret = false; + + if (SSL_connect_or_accept(ssl) == 1) { + if (keep_alive_max_count > 1) { + auto count = keep_alive_max_count; + while (count > 0 && + (is_client_request || + detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, + CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec); + auto last_connection = count == 1; + auto connection_close = false; + + ret = callback(ssl, strm, last_connection, connection_close); + if (!ret || connection_close) { break; } + + count--; + } + } else { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec); + auto dummy_connection_close = false; + ret = callback(ssl, strm, true, dummy_connection_close); + } + } + + SSL_shutdown(ssl); + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + + close_socket(sock); + + return ret; } - } - - SSL_shutdown(ssl); - { - std::lock_guard guard(ctx_mutex); - SSL_free(ssl); - } - - close_socket(sock); - - return ret; -} #if OPENSSL_VERSION_NUMBER < 0x10100000L -static std::shared_ptr> openSSL_locks_; + static std::shared_ptr> openSSL_locks_; class SSLThreadLocks { public: @@ -3386,183 +3911,185 @@ private: #endif -class SSLInit { -public: - SSLInit() { -#if OPENSSL_VERSION_NUMBER >= 0x1010001fL - OPENSSL_init_ssl(OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); -#else - SSL_load_error_strings(); - SSL_library_init(); -#endif - } - - ~SSLInit() { + class SSLInit { + public: + SSLInit() { #if OPENSSL_VERSION_NUMBER < 0x1010001fL - ERR_free_strings(); + SSL_load_error_strings(); + SSL_library_init(); +#else + OPENSSL_init_ssl( + OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); #endif - } + } -private: + ~SSLInit() { +#if OPENSSL_VERSION_NUMBER < 0x1010001fL + ERR_free_strings(); +#endif + } + + private: #if OPENSSL_VERSION_NUMBER < 0x10100000L - SSLThreadLocks thread_init_; + SSLThreadLocks thread_init_; #endif -}; + }; -static SSLInit sslinit_; + static SSLInit sslinit_; -} // namespace detail + } // namespace detail // SSL socket stream implementation -inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl) - : sock_(sock), ssl_(ssl) {} + inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl, + time_t read_timeout_sec, + time_t read_timeout_usec) + : sock_(sock), ssl_(ssl), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec) {} -inline SSLSocketStream::~SSLSocketStream() {} + inline SSLSocketStream::~SSLSocketStream() {} -inline int SSLSocketStream::read(char *ptr, size_t size) { - if (SSL_pending(ssl_) > 0 || - detail::select_read(sock_, CPPHTTPLIB_READ_TIMEOUT_SECOND, - CPPHTTPLIB_READ_TIMEOUT_USECOND) > 0) { - return SSL_read(ssl_, ptr, static_cast(size)); + inline int SSLSocketStream::read(char *ptr, size_t size) { + if (SSL_pending(ssl_) > 0 || + detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } + return -1; } - return -1; -} -inline int SSLSocketStream::write(const char *ptr, size_t size) { - return SSL_write(ssl_, ptr, static_cast(size)); -} + inline int SSLSocketStream::write(const char *ptr, size_t size) { + return SSL_write(ssl_, ptr, static_cast(size)); + } -inline int SSLSocketStream::write(const char *ptr) { - return write(ptr, strlen(ptr)); -} + inline int SSLSocketStream::write(const char *ptr) { + return write(ptr, strlen(ptr)); + } -inline int SSLSocketStream::write(const std::string &s) { - return write(s.data(), s.size()); -} + inline int SSLSocketStream::write(const std::string &s) { + return write(s.data(), s.size()); + } -inline std::string SSLSocketStream::get_remote_addr() const { - return detail::get_remote_addr(sock_); -} + inline std::string SSLSocketStream::get_remote_addr() const { + return detail::get_remote_addr(sock_); + } // SSL HTTP server implementation -inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, - const char *client_ca_cert_file_path, - const char *client_ca_cert_dir_path) { - ctx_ = SSL_CTX_new(SSLv23_server_method()); + inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path, + const char *client_ca_cert_dir_path) { + ctx_ = SSL_CTX_new(SSLv23_server_method()); - if (ctx_) { - SSL_CTX_set_options(ctx_, - SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | - SSL_OP_NO_COMPRESSION | - SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); - // SSL_CTX_set_tmp_ecdh(ctx_, ecdh); - // EC_KEY_free(ecdh); + // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); + // SSL_CTX_set_tmp_ecdh(ctx_, ecdh); + // EC_KEY_free(ecdh); - if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || - SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != - 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { - // if (client_ca_cert_file_path) { - // auto list = SSL_load_client_CA_file(client_ca_cert_file_path); - // SSL_CTX_set_client_CA_list(ctx_, list); - // } + if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != + 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { + // if (client_ca_cert_file_path) { + // auto list = SSL_load_client_CA_file(client_ca_cert_file_path); + // SSL_CTX_set_client_CA_list(ctx_, list); + // } - SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, - client_ca_cert_dir_path); + SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, + client_ca_cert_dir_path); - SSL_CTX_set_verify( + SSL_CTX_set_verify( ctx_, SSL_VERIFY_PEER | - SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, + SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, nullptr); + } } } -} -inline SSLServer::~SSLServer() { - if (ctx_) { SSL_CTX_free(ctx_); } -} + inline SSLServer::~SSLServer() { + if (ctx_) { SSL_CTX_free(ctx_); } + } -inline bool SSLServer::is_valid() const { return ctx_; } + inline bool SSLServer::is_valid() const { return ctx_; } -inline bool SSLServer::process_and_close_socket(socket_t sock) { - return detail::process_and_close_socket_ssl( - false, sock, keep_alive_max_count_, ctx_, ctx_mutex_, SSL_accept, - [](SSL * /*ssl*/) { return true; }, + inline bool SSLServer::process_and_close_socket(socket_t sock) { + return detail::process_and_close_socket_ssl( + false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_, + ctx_, ctx_mutex_, SSL_accept, [](SSL * /*ssl*/) { return true; }, [this](SSL *ssl, Stream &strm, bool last_connection, bool &connection_close) { return process_request(strm, last_connection, connection_close, [&](Request &req) { req.ssl = ssl; }); }); -} + } // SSL HTTP client implementation -inline SSLClient::SSLClient(const char *host, int port, time_t timeout_sec, - const char *client_cert_path, - const char *client_key_path) + inline SSLClient::SSLClient(const char *host, int port, time_t timeout_sec, + const char *client_cert_path, + const char *client_key_path) : Client(host, port, timeout_sec) { - ctx_ = SSL_CTX_new(SSLv23_client_method()); + ctx_ = SSL_CTX_new(SSLv23_client_method()); - detail::split(&host_[0], &host_[host_.size()], '.', - [&](const char *b, const char *e) { - host_components_.emplace_back(std::string(b, e)); - }); - if (client_cert_path && client_key_path) { - if (SSL_CTX_use_certificate_file(ctx_, client_cert_path, - SSL_FILETYPE_PEM) != 1 || - SSL_CTX_use_PrivateKey_file(ctx_, client_key_path, SSL_FILETYPE_PEM) != - 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(std::string(b, e)); + }); + if (client_cert_path && client_key_path) { + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path, + SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path, SSL_FILETYPE_PEM) != + 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } } } -} -inline SSLClient::~SSLClient() { - if (ctx_) { SSL_CTX_free(ctx_); } -} + inline SSLClient::~SSLClient() { + if (ctx_) { SSL_CTX_free(ctx_); } + } -inline bool SSLClient::is_valid() const { return ctx_; } + inline bool SSLClient::is_valid() const { return ctx_; } -inline void SSLClient::set_ca_cert_path(const char *ca_cert_file_path, - const char *ca_cert_dir_path) { - if (ca_cert_file_path) { ca_cert_file_path_ = ca_cert_file_path; } - if (ca_cert_dir_path) { ca_cert_dir_path_ = ca_cert_dir_path; } -} + inline void SSLClient::set_ca_cert_path(const char *ca_cert_file_path, + const char *ca_cert_dir_path) { + if (ca_cert_file_path) { ca_cert_file_path_ = ca_cert_file_path; } + if (ca_cert_dir_path) { ca_cert_dir_path_ = ca_cert_dir_path; } + } -inline void SSLClient::enable_server_certificate_verification(bool enabled) { - server_certificate_verification_ = enabled; -} + inline void SSLClient::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; + } -inline long SSLClient::get_openssl_verify_result() const { - return verify_result_; -} + inline long SSLClient::get_openssl_verify_result() const { + return verify_result_; + } -inline SSL_CTX* SSLClient::ssl_context() const noexcept { - return ctx_; -} + inline SSL_CTX *SSLClient::ssl_context() const noexcept { return ctx_; } -inline bool SSLClient::process_and_close_socket( + inline bool SSLClient::process_and_close_socket( socket_t sock, size_t request_count, std::function - callback) { + callback) { - request_count = std::min(request_count, keep_alive_max_count_); + request_count = std::min(request_count, keep_alive_max_count_); - return is_valid() && - detail::process_and_close_socket_ssl( - true, sock, request_count, ctx_, ctx_mutex_, + return is_valid() && + detail::process_and_close_socket_ssl( + true, sock, request_count, read_timeout_sec_, read_timeout_usec_, + ctx_, ctx_mutex_, [&](SSL *ssl) { if (ca_cert_file_path_.empty()) { SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr); } else { if (!SSL_CTX_load_verify_locations( - ctx_, ca_cert_file_path_.c_str(), nullptr)) { + ctx_, ca_cert_file_path_.c_str(), nullptr)) { return false; } SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, nullptr); @@ -3596,137 +4123,137 @@ inline bool SSLClient::process_and_close_socket( bool &connection_close) { return callback(strm, last_connection, connection_close); }); -} + } -inline bool SSLClient::is_ssl() const { return true; } + inline bool SSLClient::is_ssl() const { return true; } -inline bool SSLClient::verify_host(X509 *server_cert) const { - /* Quote from RFC2818 section 3.1 "Server Identity" + inline bool SSLClient::verify_host(X509 *server_cert) const { + /* Quote from RFC2818 section 3.1 "Server Identity" - If a subjectAltName extension of type dNSName is present, that MUST - be used as the identity. Otherwise, the (most specific) Common Name - field in the Subject field of the certificate MUST be used. Although - the use of the Common Name is existing practice, it is deprecated and - Certification Authorities are encouraged to use the dNSName instead. + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. - Matching is performed using the matching rules specified by - [RFC2459]. If more than one identity of a given type is present in - the certificate (e.g., more than one dNSName name, a match in any one - of the set is considered acceptable.) Names may contain the wildcard - character * which is considered to match any single domain name - component or component fragment. E.g., *.a.com matches foo.a.com but - not bar.foo.a.com. f*.com matches foo.com but not bar.com. + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. - In some cases, the URI is specified as an IP address rather than a - hostname. In this case, the iPAddress subjectAltName must be present - in the certificate and must exactly match the IP in the URI. + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. - */ - return verify_host_with_subject_alt_name(server_cert) || - verify_host_with_common_name(server_cert); -} + */ + return verify_host_with_subject_alt_name(server_cert) || + verify_host_with_common_name(server_cert); + } -inline bool -SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { - auto ret = false; + inline bool + SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { + auto ret = false; - auto type = GEN_DNS; + auto type = GEN_DNS; - struct in6_addr addr6; - struct in_addr addr; - size_t addr_len = 0; + struct in6_addr addr6; + struct in_addr addr; + size_t addr_len = 0; #ifndef __MINGW32__ - if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { - type = GEN_IPADD; - addr_len = sizeof(struct in6_addr); - } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { - type = GEN_IPADD; - addr_len = sizeof(struct in_addr); - } + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } #endif - auto alt_names = static_cast( + auto alt_names = static_cast( X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); - if (alt_names) { - auto dsn_matched = false; - auto ip_mached = false; + if (alt_names) { + auto dsn_matched = false; + auto ip_mached = false; - auto count = sk_GENERAL_NAME_num(alt_names); + auto count = sk_GENERAL_NAME_num(alt_names); - for (auto i = 0; i < count && !dsn_matched; i++) { - auto val = sk_GENERAL_NAME_value(alt_names, i); - if (val->type == type) { - auto name = (const char *)ASN1_STRING_get0_data(val->d.ia5); - auto name_len = (size_t)ASN1_STRING_length(val->d.ia5); + for (auto i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (val->type == type) { + auto name = (const char *)ASN1_STRING_get0_data(val->d.ia5); + auto name_len = (size_t)ASN1_STRING_length(val->d.ia5); - if (strlen(name) == name_len) { - switch (type) { - case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; + if (strlen(name) == name_len) { + switch (type) { + case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; - case GEN_IPADD: - if (!memcmp(&addr6, name, addr_len) || - !memcmp(&addr, name, addr_len)) { - ip_mached = true; + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || + !memcmp(&addr, name, addr_len)) { + ip_mached = true; + } + break; } - break; } } } + + if (dsn_matched || ip_mached) { ret = true; } } - if (dsn_matched || ip_mached) { ret = true; } + GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)alt_names); + + return ret; } - GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)alt_names); + inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { + const auto subject_name = X509_get_subject_name(server_cert); - return ret; -} + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, + name, sizeof(name)); -inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { - const auto subject_name = X509_get_subject_name(server_cert); - - if (subject_name != nullptr) { - char name[BUFSIZ]; - auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, - name, sizeof(name)); - - if (name_len != -1) { return check_host_name(name, name_len); } - } - - return false; -} - -inline bool SSLClient::check_host_name(const char *pattern, - size_t pattern_len) const { - if (host_.size() == pattern_len && host_ == pattern) { return true; } - - // Wildcard match - // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 - std::vector pattern_components; - detail::split(&pattern[0], &pattern[pattern_len], '.', - [&](const char *b, const char *e) { - pattern_components.emplace_back(std::string(b, e)); - }); - - if (host_components_.size() != pattern_components.size()) { return false; } - - auto itr = pattern_components.begin(); - for (const auto &h : host_components_) { - auto &p = *itr; - if (p != h && p != "*") { - auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && - !p.compare(0, p.size() - 1, h)); - if (!partial_match) { return false; } + if (name_len != -1) { return check_host_name(name, name_len); } } - ++itr; + + return false; } - return true; -} + inline bool SSLClient::check_host_name(const char *pattern, + size_t pattern_len) const { + if (host_.size() == pattern_len && host_ == pattern) { return true; } + + // Wildcard match + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + std::vector pattern_components; + detail::split(&pattern[0], &pattern[pattern_len], '.', + [&](const char *b, const char *e) { + pattern_components.emplace_back(std::string(b, e)); + }); + + if (host_components_.size() != pattern_components.size()) { return false; } + + auto itr = pattern_components.begin(); + for (const auto &h : host_components_) { + auto &p = *itr; + if (p != h && p != "*") { + auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && + !p.compare(0, p.size() - 1, h)); + if (!partial_match) { return false; } + } + ++itr; + } + + return true; + } #endif } // namespace httplib -#endif // CPPHTTPLIB_HTTPLIB_H +#endif // CPPHTTPLIB_HTTPLIB_H \ No newline at end of file diff --git a/src/backend/common/Workers.cpp b/src/backend/common/Workers.cpp index eb348328..6ef09401 100644 --- a/src/backend/common/Workers.cpp +++ b/src/backend/common/Workers.cpp @@ -96,6 +96,11 @@ void xmrig::Workers::start(const std::vector &data) for (Thread *worker : m_workers) { worker->start(Workers::onReady); + + // This sleep is important for optimal caching! + // Threads must allocate scratchpads in order so that adjacent cores will use adjacent scratchpads + // Sub-optimal caching can result in up to 0.5% hashrate penalty + std::this_thread::sleep_for(std::chrono::milliseconds(20)); } } diff --git a/src/backend/cpu/CpuConfig.cpp b/src/backend/cpu/CpuConfig.cpp index bcf26f7c..75648578 100644 --- a/src/backend/cpu/CpuConfig.cpp +++ b/src/backend/cpu/CpuConfig.cpp @@ -65,6 +65,7 @@ static const char *kCnExtremelite = "cn-extremelite"; static const char *kRx = "rx"; static const char *kRxWOW = "rx/wow"; static const char *kRxARQ = "rx/arq"; +static const char *kRxSFX = "rx/sfx"; #endif #ifdef XMRIG_ALGO_ARGON2 @@ -200,6 +201,7 @@ void xmrig::CpuConfig::generate() m_threads.move(kRx, cpu->threads(Algorithm::RX_0)); m_threads.move(kRxWOW, cpu->threads(Algorithm::RX_WOW)); m_threads.move(kRxARQ, cpu->threads(Algorithm::RX_ARQ)); + m_threads.move(kRxSFX, cpu->threads(Algorithm::RX_SFX)); # endif generateArgon2(); diff --git a/src/backend/cpu/CpuWorker.cpp b/src/backend/cpu/CpuWorker.cpp index edb0ad5c..9c1c361c 100644 --- a/src/backend/cpu/CpuWorker.cpp +++ b/src/backend/cpu/CpuWorker.cpp @@ -191,8 +191,17 @@ void xmrig::CpuWorker::start() consumeJob(); } + uint64_t storeStatsMask = 7; + +# ifdef XMRIG_ALGO_RANDOMX + // RandomX is faster, we don't need to store stats so often + if (m_job.currentJob().algorithm().family() == Algorithm::RANDOM_X) { + storeStatsMask = 63; + } +# endif + while (!Nonce::isOutdated(Nonce::CPU, m_job.sequence())) { - if ((m_count & 0x7) == 0) { + if ((m_count & storeStatsMask) == 0) { storeStats(); } diff --git a/src/cc/CCClient.cpp b/src/cc/CCClient.cpp index f797faf6..416b3fbb 100644 --- a/src/cc/CCClient.cpp +++ b/src/cc/CCClient.cpp @@ -379,6 +379,8 @@ std::shared_ptr xmrig::CCClient::performRequest(const std::st auto res = std::make_shared(); + cli->follow_location(false); + return cli->send(req, *res) ? res : nullptr; } diff --git a/src/cc/Httpd.cpp b/src/cc/Httpd.cpp index 7ea3deea..1d4975b6 100644 --- a/src/cc/Httpd.cpp +++ b/src/cc/Httpd.cpp @@ -20,7 +20,6 @@ #include #include #include <3rdparty/cpp-httplib/httplib.h> -#include <3rdparty/base64/base64.h> #include "base/io/log/Log.h" @@ -147,31 +146,33 @@ int Httpd::basicAuth(const httplib::Request& req, httplib::Response& res) { int result = HTTP_UNAUTHORIZED; + std::string removeAddr = req.get_header_value("REMOTE_ADDR"); + if (m_config->adminUser().empty() || m_config->adminPass().empty()) { res.set_content(std::string("" "Please configure admin user and pass to view this Page." ""), CONTENT_TYPE_HTML); - LOG_ERR("[%s] 403 FORBIDDEN - Admin user/password not set!", req.remoteAddr.c_str()); + LOG_ERR("[%s] 403 FORBIDDEN - Admin user/password not set!", removeAddr.c_str()); result = HTTP_FORBIDDEN; } else { auto authHeader = req.get_header_value("Authorization"); - auto credentials = std::string("Basic ") + Base64::Encode(m_config->adminUser() + std::string(":") + m_config->adminPass()); + auto credentials = httplib::make_basic_authentication_header(m_config->adminUser(), m_config->adminPass()); - if (!authHeader.empty() && credentials == authHeader) + if (!authHeader.empty() && credentials.second == authHeader) { result = HTTP_OK; } else if (authHeader.empty()) { - LOG_WARN("[%s] 401 UNAUTHORIZED", req.remoteAddr.c_str()); + LOG_WARN("[%s] 401 UNAUTHORIZED", removeAddr.c_str()); } else { - LOG_ERR("[%s] 403 FORBIDDEN - Admin user/password wrong!", req.remoteAddr.c_str()); + LOG_ERR("[%s] 403 FORBIDDEN - Admin user/password wrong!", removeAddr.c_str()); } } @@ -184,9 +185,11 @@ int Httpd::bearerAuth(const httplib::Request& req, httplib::Response& res) { int result = HTTP_UNAUTHORIZED; + std::string removeAddr = req.get_header_value("REMOTE_ADDR"); + if (m_config->token().empty()) { - LOG_WARN("[%s] 200 OK - WARNING AccessToken not set!", req.remoteAddr.c_str()); + LOG_WARN("[%s] 200 OK - WARNING AccessToken not set!", removeAddr.c_str()); result = HTTP_OK; } else @@ -200,11 +203,11 @@ int Httpd::bearerAuth(const httplib::Request& req, httplib::Response& res) } else if (authHeader.empty()) { - LOG_WARN("[%s] 401 UNAUTHORIZED", req.remoteAddr.c_str()); + LOG_WARN("[%s] 401 UNAUTHORIZED", removeAddr.c_str()); } else { - LOG_ERR("[%s] 403 FORBIDDEN - AccessToken wrong!", req.remoteAddr.c_str()); + LOG_ERR("[%s] 403 FORBIDDEN - AccessToken wrong!", removeAddr.c_str()); result = HTTP_FORBIDDEN; } } diff --git a/src/cc/Service.cpp b/src/cc/Service.cpp index 39c4260e..62a752df 100644 --- a/src/cc/Service.cpp +++ b/src/cc/Service.cpp @@ -107,8 +107,9 @@ int Service::handleGET(const httplib::Request& req, httplib::Response& res) int resultCode = HTTP_NOT_FOUND; std::string clientId = req.get_param_value("clientId"); + std::string removeAddr = req.get_header_value("REMOTE_ADDR"); - LOG_INFO("[%s] GET %s%s%s", req.remoteAddr.c_str(), req.path.c_str(), clientId.empty() ? "" : "/?clientId=", clientId.c_str()); + LOG_INFO("[%s] GET %s%s%s", removeAddr.c_str(), req.path.c_str(), clientId.empty() ? "" : "/?clientId=", clientId.c_str()); if (req.path == "/") { @@ -140,14 +141,14 @@ int Service::handleGET(const httplib::Request& req, httplib::Response& res) } else { - LOG_WARN("[%s] 404 NOT FOUND (%s)", req.remoteAddr.c_str(), req.path.c_str()); + LOG_WARN("[%s] 404 NOT FOUND (%s)", removeAddr.c_str(), req.path.c_str()); } } else { resultCode = HTTP_BAD_REQUEST; LOG_ERR("[%s] 400 BAD REQUEST - Request does not contain clientId (%s)", - req.remoteAddr.c_str(), req.path.c_str()); + removeAddr.c_str(), req.path.c_str()); } } @@ -161,8 +162,9 @@ int Service::handlePOST(const httplib::Request& req, httplib::Response& res) int resultCode = HTTP_NOT_FOUND; std::string clientId = req.get_param_value("clientId"); + std::string removeAddr = req.get_header_value("REMOTE_ADDR"); - LOG_INFO("[%s] POST %s%s%s", req.remoteAddr.c_str(), req.path.c_str(), clientId.empty() ? "" : "/?clientId=", clientId.c_str()); + LOG_INFO("[%s] POST %s%s%s", removeAddr.c_str(), req.path.c_str(), clientId.empty() ? "" : "/?clientId=", clientId.c_str()); if (!clientId.empty()) { @@ -185,7 +187,7 @@ int Service::handlePOST(const httplib::Request& req, httplib::Response& res) else { resultCode = HTTP_BAD_REQUEST; - LOG_WARN("[%s] 400 BAD REQUEST - Request does not contain clientId (%s)", req.remoteAddr.c_str(), req.path.c_str()); + LOG_WARN("[%s] 400 BAD REQUEST - Request does not contain clientId (%s)", removeAddr.c_str(), req.path.c_str()); } } else @@ -196,7 +198,7 @@ int Service::handlePOST(const httplib::Request& req, httplib::Response& res) } else { - LOG_WARN("[%s] 404 NOT FOUND (%s)", req.remoteAddr.c_str(), req.path.c_str()); + LOG_WARN("[%s] 404 NOT FOUND (%s)", removeAddr.c_str(), req.path.c_str()); } } @@ -271,12 +273,14 @@ int Service::setClientStatus(const httplib::Request& req, const std::string& cli { int resultCode = HTTP_BAD_REQUEST; + std::string removeAddr = req.get_header_value("REMOTE_ADDR"); + rapidjson::Document document; if (!document.Parse(req.body.c_str()).HasParseError()) { ClientStatus clientStatus; clientStatus.parseFromJson(document); - clientStatus.setExternalIp(req.remoteAddr); + clientStatus.setExternalIp(removeAddr); setClientLog(static_cast(m_config->clientLogHistory()), clientId, clientStatus.getLog()); @@ -294,7 +298,7 @@ int Service::setClientStatus(const httplib::Request& req, const std::string& cli else { LOG_ERR("[%s] ClientStatus for client '%s' - Parse Error Occured: %d", - req.remoteAddr.c_str(), clientId.c_str(), document.GetParseError()); + removeAddr.c_str(), clientId.c_str(), document.GetParseError()); } return resultCode; diff --git a/src/crypto/cn/CnAlgo.h b/src/crypto/cn/CnAlgo.h index d51d5a1d..f8439395 100644 --- a/src/crypto/cn/CnAlgo.h +++ b/src/crypto/cn/CnAlgo.h @@ -183,6 +183,7 @@ private: 0, // RX_WOW 0, // RX_LOKI 0, // RX_ARQ + 0, // RX_SFX # endif # ifdef XMRIG_ALGO_ARGON2 0, // AR2_CHUKWA @@ -227,6 +228,7 @@ private: 0, // RX_WOW 0, // RX_LOKI 0, // RX_ARQ + 0, // RX_SFX # endif # ifdef XMRIG_ALGO_ARGON2 0, // AR2_CHUKWA @@ -271,6 +273,7 @@ private: Algorithm::INVALID, // RX_WOW Algorithm::INVALID, // RX_LOKI Algorithm::INVALID, // RX_ARQ + Algorithm::INVALID, // RX_SFX # endif # ifdef XMRIG_ALGO_ARGON2 Algorithm::INVALID, // AR2_CHUKWA diff --git a/src/crypto/common/Algorithm.cpp b/src/crypto/common/Algorithm.cpp index 15afcb90..a9c433f7 100644 --- a/src/crypto/common/Algorithm.cpp +++ b/src/crypto/common/Algorithm.cpp @@ -125,6 +125,8 @@ static AlgoName const algorithm_names[] = { { "RandomXL", nullptr, Algorithm::RX_LOKI }, { "randomx/arq", "rx/arq", Algorithm::RX_ARQ }, { "RandomARQ", nullptr, Algorithm::RX_ARQ }, + { "randomx/sfx", "rx/sfx", Algorithm::RX_SFX }, + { "RandomSFX", nullptr, Algorithm::RX_SFX }, # endif # ifdef XMRIG_ALGO_ARGON2 { "argon2/chukwa", nullptr, Algorithm::AR2_CHUKWA }, @@ -155,6 +157,7 @@ size_t xmrig::Algorithm::l2() const switch (m_id) { case RX_0: case RX_LOKI: + case RX_SFX: return 0x40000; case RX_WOW: @@ -188,6 +191,7 @@ size_t xmrig::Algorithm::l3() const switch (m_id) { case RX_0: case RX_LOKI: + case RX_SFX: return oneMiB * 2; case RX_WOW: @@ -294,6 +298,7 @@ xmrig::Algorithm::Family xmrig::Algorithm::family(Id id) case RX_WOW: case RX_LOKI: case RX_ARQ: + case RX_SFX: return RANDOM_X; # endif diff --git a/src/crypto/common/Algorithm.h b/src/crypto/common/Algorithm.h index e7855e1a..c999548c 100644 --- a/src/crypto/common/Algorithm.h +++ b/src/crypto/common/Algorithm.h @@ -77,6 +77,7 @@ public: RX_WOW, // "rx/wow" RandomWOW (Wownero). RX_LOKI, // "rx/loki" RandomXL (Loki). RX_ARQ, // "rx/arq" RandomARQ (Arqma). + RX_SFX, // "rx/sfx" RandomSFX (Safex). # endif # ifdef XMRIG_ALGO_ARGON2 AR2_CHUKWA, // "argon2/chukwa" diff --git a/src/crypto/randomx/aes_hash.cpp b/src/crypto/randomx/aes_hash.cpp index fe149dfe..d7216be7 100644 --- a/src/crypto/randomx/aes_hash.cpp +++ b/src/crypto/randomx/aes_hash.cpp @@ -51,52 +51,52 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ template void hashAes1Rx4(const void *input, size_t inputSize, void *hash) { - const uint8_t* inptr = (uint8_t*)input; - const uint8_t* inputEnd = inptr + inputSize; + const uint8_t* inptr = (uint8_t*)input; + const uint8_t* inputEnd = inptr + inputSize; - rx_vec_i128 state0, state1, state2, state3; - rx_vec_i128 in0, in1, in2, in3; + rx_vec_i128 state0, state1, state2, state3; + rx_vec_i128 in0, in1, in2, in3; - //intial state - state0 = rx_set_int_vec_i128(AES_HASH_1R_STATE0); - state1 = rx_set_int_vec_i128(AES_HASH_1R_STATE1); - state2 = rx_set_int_vec_i128(AES_HASH_1R_STATE2); - state3 = rx_set_int_vec_i128(AES_HASH_1R_STATE3); + //intial state + state0 = rx_set_int_vec_i128(AES_HASH_1R_STATE0); + state1 = rx_set_int_vec_i128(AES_HASH_1R_STATE1); + state2 = rx_set_int_vec_i128(AES_HASH_1R_STATE2); + state3 = rx_set_int_vec_i128(AES_HASH_1R_STATE3); - //process 64 bytes at a time in 4 lanes - while (inptr < inputEnd) { - in0 = rx_load_vec_i128((rx_vec_i128*)inptr + 0); - in1 = rx_load_vec_i128((rx_vec_i128*)inptr + 1); - in2 = rx_load_vec_i128((rx_vec_i128*)inptr + 2); - in3 = rx_load_vec_i128((rx_vec_i128*)inptr + 3); + //process 64 bytes at a time in 4 lanes + while (inptr < inputEnd) { + in0 = rx_load_vec_i128((rx_vec_i128*)inptr + 0); + in1 = rx_load_vec_i128((rx_vec_i128*)inptr + 1); + in2 = rx_load_vec_i128((rx_vec_i128*)inptr + 2); + in3 = rx_load_vec_i128((rx_vec_i128*)inptr + 3); - state0 = aesenc(state0, in0); - state1 = aesdec(state1, in1); - state2 = aesenc(state2, in2); - state3 = aesdec(state3, in3); + state0 = aesenc(state0, in0); + state1 = aesdec(state1, in1); + state2 = aesenc(state2, in2); + state3 = aesdec(state3, in3); - inptr += 64; - } + inptr += 64; + } - //two extra rounds to achieve full diffusion - rx_vec_i128 xkey0 = rx_set_int_vec_i128(AES_HASH_1R_XKEY0); - rx_vec_i128 xkey1 = rx_set_int_vec_i128(AES_HASH_1R_XKEY1); + //two extra rounds to achieve full diffusion + rx_vec_i128 xkey0 = rx_set_int_vec_i128(AES_HASH_1R_XKEY0); + rx_vec_i128 xkey1 = rx_set_int_vec_i128(AES_HASH_1R_XKEY1); - state0 = aesenc(state0, xkey0); - state1 = aesdec(state1, xkey0); - state2 = aesenc(state2, xkey0); - state3 = aesdec(state3, xkey0); + state0 = aesenc(state0, xkey0); + state1 = aesdec(state1, xkey0); + state2 = aesenc(state2, xkey0); + state3 = aesdec(state3, xkey0); - state0 = aesenc(state0, xkey1); - state1 = aesdec(state1, xkey1); - state2 = aesenc(state2, xkey1); - state3 = aesdec(state3, xkey1); + state0 = aesenc(state0, xkey1); + state1 = aesdec(state1, xkey1); + state2 = aesenc(state2, xkey1); + state3 = aesdec(state3, xkey1); - //output hash - rx_store_vec_i128((rx_vec_i128*)hash + 0, state0); - rx_store_vec_i128((rx_vec_i128*)hash + 1, state1); - rx_store_vec_i128((rx_vec_i128*)hash + 2, state2); - rx_store_vec_i128((rx_vec_i128*)hash + 3, state3); + //output hash + rx_store_vec_i128((rx_vec_i128*)hash + 0, state0); + rx_store_vec_i128((rx_vec_i128*)hash + 1, state1); + rx_store_vec_i128((rx_vec_i128*)hash + 2, state2); + rx_store_vec_i128((rx_vec_i128*)hash + 3, state3); } template void hashAes1Rx4(const void *input, size_t inputSize, void *hash); @@ -119,40 +119,40 @@ template void hashAes1Rx4(const void *input, size_t inputSize, void *hash) */ template void fillAes1Rx4(void *state, size_t outputSize, void *buffer) { - const uint8_t* outptr = (uint8_t*)buffer; - const uint8_t* outputEnd = outptr + outputSize; + const uint8_t* outptr = (uint8_t*)buffer; + const uint8_t* outputEnd = outptr + outputSize; - rx_vec_i128 state0, state1, state2, state3; - rx_vec_i128 key0, key1, key2, key3; + rx_vec_i128 state0, state1, state2, state3; + rx_vec_i128 key0, key1, key2, key3; - key0 = rx_set_int_vec_i128(AES_GEN_1R_KEY0); - key1 = rx_set_int_vec_i128(AES_GEN_1R_KEY1); - key2 = rx_set_int_vec_i128(AES_GEN_1R_KEY2); - key3 = rx_set_int_vec_i128(AES_GEN_1R_KEY3); + key0 = rx_set_int_vec_i128(AES_GEN_1R_KEY0); + key1 = rx_set_int_vec_i128(AES_GEN_1R_KEY1); + key2 = rx_set_int_vec_i128(AES_GEN_1R_KEY2); + key3 = rx_set_int_vec_i128(AES_GEN_1R_KEY3); - state0 = rx_load_vec_i128((rx_vec_i128*)state + 0); - state1 = rx_load_vec_i128((rx_vec_i128*)state + 1); - state2 = rx_load_vec_i128((rx_vec_i128*)state + 2); - state3 = rx_load_vec_i128((rx_vec_i128*)state + 3); + state0 = rx_load_vec_i128((rx_vec_i128*)state + 0); + state1 = rx_load_vec_i128((rx_vec_i128*)state + 1); + state2 = rx_load_vec_i128((rx_vec_i128*)state + 2); + state3 = rx_load_vec_i128((rx_vec_i128*)state + 3); - while (outptr < outputEnd) { - state0 = aesdec(state0, key0); - state1 = aesenc(state1, key1); - state2 = aesdec(state2, key2); - state3 = aesenc(state3, key3); + while (outptr < outputEnd) { + state0 = aesdec(state0, key0); + state1 = aesenc(state1, key1); + state2 = aesdec(state2, key2); + state3 = aesenc(state3, key3); - rx_store_vec_i128((rx_vec_i128*)outptr + 0, state0); - rx_store_vec_i128((rx_vec_i128*)outptr + 1, state1); - rx_store_vec_i128((rx_vec_i128*)outptr + 2, state2); - rx_store_vec_i128((rx_vec_i128*)outptr + 3, state3); + rx_store_vec_i128((rx_vec_i128*)outptr + 0, state0); + rx_store_vec_i128((rx_vec_i128*)outptr + 1, state1); + rx_store_vec_i128((rx_vec_i128*)outptr + 2, state2); + rx_store_vec_i128((rx_vec_i128*)outptr + 3, state3); - outptr += 64; - } + outptr += 64; + } - rx_store_vec_i128((rx_vec_i128*)state + 0, state0); - rx_store_vec_i128((rx_vec_i128*)state + 1, state1); - rx_store_vec_i128((rx_vec_i128*)state + 2, state2); - rx_store_vec_i128((rx_vec_i128*)state + 3, state3); + rx_store_vec_i128((rx_vec_i128*)state + 0, state0); + rx_store_vec_i128((rx_vec_i128*)state + 1, state1); + rx_store_vec_i128((rx_vec_i128*)state + 2, state2); + rx_store_vec_i128((rx_vec_i128*)state + 3, state3); } template void fillAes1Rx4(void *state, size_t outputSize, void *buffer); @@ -160,55 +160,136 @@ template void fillAes1Rx4(void *state, size_t outputSize, void *buffer); template void fillAes4Rx4(void *state, size_t outputSize, void *buffer) { - const uint8_t* outptr = (uint8_t*)buffer; - const uint8_t* outputEnd = outptr + outputSize; + const uint8_t* outptr = (uint8_t*)buffer; + const uint8_t* outputEnd = outptr + outputSize; - rx_vec_i128 state0, state1, state2, state3; - rx_vec_i128 key0, key1, key2, key3, key4, key5, key6, key7; + rx_vec_i128 state0, state1, state2, state3; + rx_vec_i128 key0, key1, key2, key3, key4, key5, key6, key7; - key0 = RandomX_CurrentConfig.fillAes4Rx4_Key[0]; - key1 = RandomX_CurrentConfig.fillAes4Rx4_Key[1]; - key2 = RandomX_CurrentConfig.fillAes4Rx4_Key[2]; - key3 = RandomX_CurrentConfig.fillAes4Rx4_Key[3]; - key4 = RandomX_CurrentConfig.fillAes4Rx4_Key[4]; - key5 = RandomX_CurrentConfig.fillAes4Rx4_Key[5]; - key6 = RandomX_CurrentConfig.fillAes4Rx4_Key[6]; - key7 = RandomX_CurrentConfig.fillAes4Rx4_Key[7]; + key0 = RandomX_CurrentConfig.fillAes4Rx4_Key[0]; + key1 = RandomX_CurrentConfig.fillAes4Rx4_Key[1]; + key2 = RandomX_CurrentConfig.fillAes4Rx4_Key[2]; + key3 = RandomX_CurrentConfig.fillAes4Rx4_Key[3]; + key4 = RandomX_CurrentConfig.fillAes4Rx4_Key[4]; + key5 = RandomX_CurrentConfig.fillAes4Rx4_Key[5]; + key6 = RandomX_CurrentConfig.fillAes4Rx4_Key[6]; + key7 = RandomX_CurrentConfig.fillAes4Rx4_Key[7]; - state0 = rx_load_vec_i128((rx_vec_i128*)state + 0); - state1 = rx_load_vec_i128((rx_vec_i128*)state + 1); - state2 = rx_load_vec_i128((rx_vec_i128*)state + 2); - state3 = rx_load_vec_i128((rx_vec_i128*)state + 3); + state0 = rx_load_vec_i128((rx_vec_i128*)state + 0); + state1 = rx_load_vec_i128((rx_vec_i128*)state + 1); + state2 = rx_load_vec_i128((rx_vec_i128*)state + 2); + state3 = rx_load_vec_i128((rx_vec_i128*)state + 3); - while (outptr < outputEnd) { - state0 = aesdec(state0, key0); - state1 = aesenc(state1, key0); - state2 = aesdec(state2, key4); - state3 = aesenc(state3, key4); + while (outptr < outputEnd) { + state0 = aesdec(state0, key0); + state1 = aesenc(state1, key0); + state2 = aesdec(state2, key4); + state3 = aesenc(state3, key4); - state0 = aesdec(state0, key1); - state1 = aesenc(state1, key1); - state2 = aesdec(state2, key5); - state3 = aesenc(state3, key5); + state0 = aesdec(state0, key1); + state1 = aesenc(state1, key1); + state2 = aesdec(state2, key5); + state3 = aesenc(state3, key5); - state0 = aesdec(state0, key2); - state1 = aesenc(state1, key2); - state2 = aesdec(state2, key6); - state3 = aesenc(state3, key6); + state0 = aesdec(state0, key2); + state1 = aesenc(state1, key2); + state2 = aesdec(state2, key6); + state3 = aesenc(state3, key6); - state0 = aesdec(state0, key3); - state1 = aesenc(state1, key3); - state2 = aesdec(state2, key7); - state3 = aesenc(state3, key7); + state0 = aesdec(state0, key3); + state1 = aesenc(state1, key3); + state2 = aesdec(state2, key7); + state3 = aesenc(state3, key7); - rx_store_vec_i128((rx_vec_i128*)outptr + 0, state0); - rx_store_vec_i128((rx_vec_i128*)outptr + 1, state1); - rx_store_vec_i128((rx_vec_i128*)outptr + 2, state2); - rx_store_vec_i128((rx_vec_i128*)outptr + 3, state3); + rx_store_vec_i128((rx_vec_i128*)outptr + 0, state0); + rx_store_vec_i128((rx_vec_i128*)outptr + 1, state1); + rx_store_vec_i128((rx_vec_i128*)outptr + 2, state2); + rx_store_vec_i128((rx_vec_i128*)outptr + 3, state3); - outptr += 64; - } + outptr += 64; + } } template void fillAes4Rx4(void *state, size_t outputSize, void *buffer); template void fillAes4Rx4(void *state, size_t outputSize, void *buffer); + +template +void hashAndFillAes1Rx4(void *scratchpad, size_t scratchpadSize, void *hash, void* fill_state) { + uint8_t* scratchpadPtr = (uint8_t*)scratchpad; + const uint8_t* scratchpadEnd = scratchpadPtr + scratchpadSize; + + // initial state + rx_vec_i128 hash_state0 = rx_set_int_vec_i128(AES_HASH_1R_STATE0); + rx_vec_i128 hash_state1 = rx_set_int_vec_i128(AES_HASH_1R_STATE1); + rx_vec_i128 hash_state2 = rx_set_int_vec_i128(AES_HASH_1R_STATE2); + rx_vec_i128 hash_state3 = rx_set_int_vec_i128(AES_HASH_1R_STATE3); + + const rx_vec_i128 key0 = rx_set_int_vec_i128(AES_GEN_1R_KEY0); + const rx_vec_i128 key1 = rx_set_int_vec_i128(AES_GEN_1R_KEY1); + const rx_vec_i128 key2 = rx_set_int_vec_i128(AES_GEN_1R_KEY2); + const rx_vec_i128 key3 = rx_set_int_vec_i128(AES_GEN_1R_KEY3); + + rx_vec_i128 fill_state0 = rx_load_vec_i128((rx_vec_i128*)fill_state + 0); + rx_vec_i128 fill_state1 = rx_load_vec_i128((rx_vec_i128*)fill_state + 1); + rx_vec_i128 fill_state2 = rx_load_vec_i128((rx_vec_i128*)fill_state + 2); + rx_vec_i128 fill_state3 = rx_load_vec_i128((rx_vec_i128*)fill_state + 3); + + constexpr int PREFETCH_DISTANCE = 4096; + const char* prefetchPtr = ((const char*)scratchpad) + PREFETCH_DISTANCE; + scratchpadEnd -= PREFETCH_DISTANCE; + + for (int i = 0; i < 2; ++i) { + //process 64 bytes at a time in 4 lanes + while (scratchpadPtr < scratchpadEnd) { + hash_state0 = aesenc(hash_state0, rx_load_vec_i128((rx_vec_i128*)scratchpadPtr + 0)); + hash_state1 = aesdec(hash_state1, rx_load_vec_i128((rx_vec_i128*)scratchpadPtr + 1)); + hash_state2 = aesenc(hash_state2, rx_load_vec_i128((rx_vec_i128*)scratchpadPtr + 2)); + hash_state3 = aesdec(hash_state3, rx_load_vec_i128((rx_vec_i128*)scratchpadPtr + 3)); + + fill_state0 = aesdec(fill_state0, key0); + fill_state1 = aesenc(fill_state1, key1); + fill_state2 = aesdec(fill_state2, key2); + fill_state3 = aesenc(fill_state3, key3); + + rx_store_vec_i128((rx_vec_i128*)scratchpadPtr + 0, fill_state0); + rx_store_vec_i128((rx_vec_i128*)scratchpadPtr + 1, fill_state1); + rx_store_vec_i128((rx_vec_i128*)scratchpadPtr + 2, fill_state2); + rx_store_vec_i128((rx_vec_i128*)scratchpadPtr + 3, fill_state3); + + rx_prefetch_t0(prefetchPtr); + + scratchpadPtr += 64; + prefetchPtr += 64; + } + prefetchPtr = (const char*) scratchpad; + scratchpadEnd += PREFETCH_DISTANCE; + } + + rx_store_vec_i128((rx_vec_i128*)fill_state + 0, fill_state0); + rx_store_vec_i128((rx_vec_i128*)fill_state + 1, fill_state1); + rx_store_vec_i128((rx_vec_i128*)fill_state + 2, fill_state2); + rx_store_vec_i128((rx_vec_i128*)fill_state + 3, fill_state3); + + //two extra rounds to achieve full diffusion + rx_vec_i128 xkey0 = rx_set_int_vec_i128(AES_HASH_1R_XKEY0); + rx_vec_i128 xkey1 = rx_set_int_vec_i128(AES_HASH_1R_XKEY1); + + hash_state0 = aesenc(hash_state0, xkey0); + hash_state1 = aesdec(hash_state1, xkey0); + hash_state2 = aesenc(hash_state2, xkey0); + hash_state3 = aesdec(hash_state3, xkey0); + + hash_state0 = aesenc(hash_state0, xkey1); + hash_state1 = aesdec(hash_state1, xkey1); + hash_state2 = aesenc(hash_state2, xkey1); + hash_state3 = aesdec(hash_state3, xkey1); + + //output hash + rx_store_vec_i128((rx_vec_i128*)hash + 0, hash_state0); + rx_store_vec_i128((rx_vec_i128*)hash + 1, hash_state1); + rx_store_vec_i128((rx_vec_i128*)hash + 2, hash_state2); + rx_store_vec_i128((rx_vec_i128*)hash + 3, hash_state3); +} + +template void hashAndFillAes1Rx4(void *scratchpad, size_t scratchpadSize, void *hash, void* fill_state); +template void hashAndFillAes1Rx4(void *scratchpad, size_t scratchpadSize, void *hash, void* fill_state); \ No newline at end of file diff --git a/src/crypto/randomx/aes_hash.hpp b/src/crypto/randomx/aes_hash.hpp index b4d0e940..0dc87d25 100644 --- a/src/crypto/randomx/aes_hash.hpp +++ b/src/crypto/randomx/aes_hash.hpp @@ -38,3 +38,6 @@ void fillAes1Rx4(void *state, size_t outputSize, void *buffer); template void fillAes4Rx4(void *state, size_t outputSize, void *buffer); + +template +void hashAndFillAes1Rx4(void *scratchpad, size_t scratchpadSize, void *hash, void* fill_state); \ No newline at end of file diff --git a/src/crypto/randomx/intrin_portable.h b/src/crypto/randomx/intrin_portable.h index e4916096..76bfd874 100644 --- a/src/crypto/randomx/intrin_portable.h +++ b/src/crypto/randomx/intrin_portable.h @@ -102,6 +102,7 @@ typedef __m128d rx_vec_f128; #define rx_aligned_alloc(a, b) _mm_malloc(a,b) #define rx_aligned_free(a) _mm_free(a) #define rx_prefetch_nta(x) _mm_prefetch((const char *)(x), _MM_HINT_NTA) +#define rx_prefetch_t0(x) _mm_prefetch((const char *)(x), _MM_HINT_T0) #define rx_load_vec_f128 _mm_load_pd #define rx_store_vec_f128 _mm_store_pd @@ -201,6 +202,7 @@ typedef union{ #define rx_aligned_alloc(a, b) malloc(a) #define rx_aligned_free(a) free(a) #define rx_prefetch_nta(x) +#define rx_prefetch_t0(x) /* Splat 64-bit long long to 2 64-bit long longs */ FORCE_INLINE __m128i vec_splat2sd (int64_t scalar) @@ -376,11 +378,142 @@ FORCE_INLINE rx_vec_f128 rx_cvt_packed_int_vec_f128(const void* addr) { #define RANDOMX_DEFAULT_FENV -void rx_reset_float_state(); +#elif defined(__aarch64__) -void rx_set_rounding_mode(uint32_t mode); +#include +#include +#include -#else //end altivec +typedef uint8x16_t rx_vec_i128; +typedef float64x2_t rx_vec_f128; + +inline void* rx_aligned_alloc(size_t size, size_t align) { + void* p; + if (posix_memalign(&p, align, size) == 0) + return p; + + return 0; +}; + +#define rx_aligned_free(a) free(a) + +inline void rx_prefetch_nta(void* ptr) { + asm volatile ("prfm pldl1strm, [%0]\n" : : "r" (ptr)); +} + +inline void rx_prefetch_t0(const void* ptr) { + asm volatile ("prfm pldl1strm, [%0]\n" : : "r" (ptr)); +} + +FORCE_INLINE rx_vec_f128 rx_load_vec_f128(const double* pd) { + return vld1q_f64((const float64_t*)pd); +} + +FORCE_INLINE void rx_store_vec_f128(double* mem_addr, rx_vec_f128 val) { + vst1q_f64((float64_t*)mem_addr, val); +} + +FORCE_INLINE rx_vec_f128 rx_swap_vec_f128(rx_vec_f128 a) { + float64x2_t temp; + temp = vcopyq_laneq_f64(temp, 1, a, 1); + a = vcopyq_laneq_f64(a, 1, a, 0); + return vcopyq_laneq_f64(a, 0, temp, 1); +} + +FORCE_INLINE rx_vec_f128 rx_set_vec_f128(uint64_t x1, uint64_t x0) { + uint64x2_t temp0 = vdupq_n_u64(x0); + uint64x2_t temp1 = vdupq_n_u64(x1); + return vreinterpretq_f64_u64(vcopyq_laneq_u64(temp0, 1, temp1, 0)); +} + +FORCE_INLINE rx_vec_f128 rx_set1_vec_f128(uint64_t x) { + return vreinterpretq_f64_u64(vdupq_n_u64(x)); +} + +#define rx_add_vec_f128 vaddq_f64 +#define rx_sub_vec_f128 vsubq_f64 +#define rx_mul_vec_f128 vmulq_f64 +#define rx_div_vec_f128 vdivq_f64 +#define rx_sqrt_vec_f128 vsqrtq_f64 + +FORCE_INLINE rx_vec_f128 rx_xor_vec_f128(rx_vec_f128 a, rx_vec_f128 b) { + return vreinterpretq_f64_u8(veorq_u8(vreinterpretq_u8_f64(a), vreinterpretq_u8_f64(b))); +} + +FORCE_INLINE rx_vec_f128 rx_and_vec_f128(rx_vec_f128 a, rx_vec_f128 b) { + return vreinterpretq_f64_u8(vandq_u8(vreinterpretq_u8_f64(a), vreinterpretq_u8_f64(b))); +} + +FORCE_INLINE rx_vec_f128 rx_or_vec_f128(rx_vec_f128 a, rx_vec_f128 b) { + return vreinterpretq_f64_u8(vorrq_u8(vreinterpretq_u8_f64(a), vreinterpretq_u8_f64(b))); +} + +#ifdef __ARM_FEATURE_CRYPTO + + +FORCE_INLINE rx_vec_i128 rx_aesenc_vec_i128(rx_vec_i128 a, rx_vec_i128 key) { + const uint8x16_t zero = { 0 }; + return vaesmcq_u8(vaeseq_u8(a, zero)) ^ key; +} + +FORCE_INLINE rx_vec_i128 rx_aesdec_vec_i128(rx_vec_i128 a, rx_vec_i128 key) { + const uint8x16_t zero = { 0 }; + return vaesimcq_u8(vaesdq_u8(a, zero)) ^ key; +} + +#define HAVE_AES + +#endif + +#define rx_xor_vec_i128 veorq_u8 + +FORCE_INLINE int rx_vec_i128_x(rx_vec_i128 a) { + return vgetq_lane_s32(vreinterpretq_s32_u8(a), 0); +} + +FORCE_INLINE int rx_vec_i128_y(rx_vec_i128 a) { + return vgetq_lane_s32(vreinterpretq_s32_u8(a), 1); +} + +FORCE_INLINE int rx_vec_i128_z(rx_vec_i128 a) { + return vgetq_lane_s32(vreinterpretq_s32_u8(a), 2); +} + +FORCE_INLINE int rx_vec_i128_w(rx_vec_i128 a) { + return vgetq_lane_s32(vreinterpretq_s32_u8(a), 3); +} + +FORCE_INLINE rx_vec_i128 rx_set_int_vec_i128(int _I3, int _I2, int _I1, int _I0) { + int32_t data[4]; + data[0] = _I0; + data[1] = _I1; + data[2] = _I2; + data[3] = _I3; + return vreinterpretq_u8_s32(vld1q_s32(data)); +}; + +#define rx_xor_vec_i128 veorq_u8 + +FORCE_INLINE rx_vec_i128 rx_load_vec_i128(const rx_vec_i128* mem_addr) { + return vld1q_u8((const uint8_t*)mem_addr); +} + +FORCE_INLINE void rx_store_vec_i128(rx_vec_i128* mem_addr, rx_vec_i128 val) { + vst1q_u8((uint8_t*)mem_addr, val); +} + +FORCE_INLINE rx_vec_f128 rx_cvt_packed_int_vec_f128(const void* addr) { + double lo = unsigned32ToSigned2sCompl(load32((uint8_t*)addr + 0)); + double hi = unsigned32ToSigned2sCompl(load32((uint8_t*)addr + 4)); + rx_vec_f128 x; + x = vsetq_lane_f64(lo, x, 0); + x = vsetq_lane_f64(hi, x, 1); + return x; +} + +#define RANDOMX_DEFAULT_FENV + +#else //portable fallback #include #include @@ -405,6 +538,7 @@ typedef union { #define rx_aligned_alloc(a, b) malloc(a) #define rx_aligned_free(a) free(a) #define rx_prefetch_nta(x) +#define rx_prefetch_t0(x) FORCE_INLINE rx_vec_f128 rx_load_vec_f128(const double* pd) { rx_vec_f128 x; @@ -487,7 +621,6 @@ FORCE_INLINE rx_vec_f128 rx_set1_vec_f128(uint64_t x) { return v; } - FORCE_INLINE rx_vec_f128 rx_xor_vec_f128(rx_vec_f128 a, rx_vec_f128 b) { rx_vec_f128 x; x.i.u64[0] = a.i.u64[0] ^ b.i.u64[0]; @@ -578,10 +711,6 @@ FORCE_INLINE rx_vec_f128 rx_cvt_packed_int_vec_f128(const void* addr) { #define RANDOMX_DEFAULT_FENV -void rx_reset_float_state(); - -void rx_set_rounding_mode(uint32_t mode); - #endif #ifndef HAVE_AES @@ -598,8 +727,16 @@ FORCE_INLINE rx_vec_i128 rx_aesdec_vec_i128(rx_vec_i128 v, rx_vec_i128 rkey) { } #endif +#ifdef RANDOMX_DEFAULT_FENV + +void rx_reset_float_state(); + +void rx_set_rounding_mode(uint32_t mode); + +#endif + double loadDoublePortable(const void* addr); uint64_t mulh(uint64_t, uint64_t); int64_t smulh(int64_t, int64_t); uint64_t rotl64(uint64_t, unsigned int); -uint64_t rotr64(uint64_t, unsigned int); +uint64_t rotr64(uint64_t, unsigned int); \ No newline at end of file diff --git a/src/crypto/randomx/jit_compiler_x86.cpp b/src/crypto/randomx/jit_compiler_x86.cpp index 2f6cfbda..a0fe0674 100644 --- a/src/crypto/randomx/jit_compiler_x86.cpp +++ b/src/crypto/randomx/jit_compiler_x86.cpp @@ -29,6 +29,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include #include #include +#include #include "crypto/randomx/jit_compiler_x86.hpp" #include "crypto/randomx/jit_compiler_x86_static.hpp" #include "crypto/randomx/superscalar.hpp" @@ -36,6 +37,12 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "crypto/randomx/reciprocal.h" #include "crypto/randomx/virtual_memory.hpp" +#ifdef _MSC_VER +# include +#else +# include +#endif + namespace randomx { /* @@ -108,7 +115,7 @@ namespace randomx { const int32_t codeSshPrefetchSize = codeShhEnd - codeShhPrefetch; const int32_t codeSshInitSize = codeProgramEnd - codeShhInit; - const int32_t epilogueOffset = CodeSize - epilogueSize; + const int32_t epilogueOffset = (CodeSize - epilogueSize) & ~63; constexpr int32_t superScalarHashOffset = 32768; static const uint8_t REX_ADD_RR[] = { 0x4d, 0x03 }; @@ -183,6 +190,7 @@ namespace randomx { static const uint8_t REX_ADD_I[] = { 0x49, 0x81 }; static const uint8_t REX_TEST[] = { 0x49, 0xF7 }; static const uint8_t JZ[] = { 0x0f, 0x84 }; + static const uint8_t JZ_SHORT = 0x74; static const uint8_t RET = 0xc3; static const uint8_t LEA_32[] = { 0x41, 0x8d }; static const uint8_t MOVNTI[] = { 0x4c, 0x0f, 0xc3 }; @@ -197,20 +205,100 @@ namespace randomx { static const uint8_t NOP7[] = { 0x0F, 0x1F, 0x80, 0x00, 0x00, 0x00, 0x00 }; static const uint8_t NOP8[] = { 0x0F, 0x1F, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00 }; -// static const uint8_t* NOPX[] = { NOP1, NOP2, NOP3, NOP4, NOP5, NOP6, NOP7, NOP8 }; + static const uint8_t* NOPX[] = { NOP1, NOP2, NOP3, NOP4, NOP5, NOP6, NOP7, NOP8 }; + + static const uint8_t JMP_ALIGN_PREFIX[14][16] = { + {}, + {0x2E}, + {0x2E, 0x2E}, + {0x2E, 0x2E, 0x2E}, + {0x2E, 0x2E, 0x2E, 0x2E}, + {0x2E, 0x2E, 0x2E, 0x2E, 0x2E}, + {0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E}, + {0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E}, + {0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E}, + {0x90, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E}, + {0x66, 0x90, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E}, + {0x66, 0x66, 0x90, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E}, + {0x0F, 0x1F, 0x40, 0x00, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E}, + {0x0F, 0x1F, 0x44, 0x00, 0x00, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E, 0x2E}, + }; + + bool JitCompilerX86::BranchesWithin32B = false; size_t JitCompilerX86::getCodeSize() { return codePos < prologueSize ? 0 : codePos - prologueSize; } + static inline void cpuid(uint32_t level, int32_t output[4]) + { + memset(output, 0, sizeof(int32_t) * 4); + +# ifdef _MSC_VER + __cpuid(output, static_cast(level)); +# else + __cpuid_count(level, 0, output[0], output[1], output[2], output[3]); +# endif + } + + // CPU-specific tweaks + void JitCompilerX86::applyTweaks() { + int32_t info[4]; + cpuid(0, info); + + int32_t manufacturer[4]; + manufacturer[0] = info[1]; + manufacturer[1] = info[3]; + manufacturer[2] = info[2]; + manufacturer[3] = 0; + + if (strcmp((const char*)manufacturer, "GenuineIntel") == 0) { + struct + { + unsigned int stepping : 4; + unsigned int model : 4; + unsigned int family : 4; + unsigned int processor_type : 2; + unsigned int reserved1 : 2; + unsigned int ext_model : 4; + unsigned int ext_family : 8; + unsigned int reserved2 : 4; + } processor_info; + + cpuid(1, info); + memcpy(&processor_info, info, sizeof(processor_info)); + + // Intel JCC erratum mitigation + if (processor_info.family == 6) { + const uint32_t model = processor_info.model | (processor_info.ext_model << 4); + const uint32_t stepping = processor_info.stepping; + + // Affected CPU models and stepping numbers are taken from https://www.intel.com/content/dam/support/us/en/documents/processors/mitigations-jump-conditional-code-erratum.pdf + BranchesWithin32B = + ((model == 0x4E) && (stepping == 0x3)) || + ((model == 0x55) && (stepping == 0x4)) || + ((model == 0x5E) && (stepping == 0x3)) || + ((model == 0x8E) && (stepping >= 0x9) && (stepping <= 0xC)) || + ((model == 0x9E) && (stepping >= 0x9) && (stepping <= 0xD)) || + ((model == 0xA6) && (stepping == 0x0)) || + ((model == 0xAE) && (stepping == 0xA)); + } + } + } + + static std::atomic codeOffset; + JitCompilerX86::JitCompilerX86() { - code = (uint8_t*)allocExecutableMemory(CodeSize); + applyTweaks(); + allocatedCode = (uint8_t*)allocExecutableMemory(CodeSize * 2); + // Shift code base address to improve caching - all threads will use different L2/L3 cache sets + code = allocatedCode + (codeOffset.fetch_add(59 * 64) % CodeSize); memcpy(code, codePrologue, prologueSize); memcpy(code + epilogueOffset, codeEpilogue, epilogueSize); } JitCompilerX86::~JitCompilerX86() { - freePagedMemory(code, CodeSize); + freePagedMemory(allocatedCode, CodeSize); } void JitCompilerX86::generateProgram(Program& prog, ProgramConfiguration& pcfg) { @@ -268,8 +356,6 @@ namespace randomx { } void JitCompilerX86::generateProgramPrologue(Program& prog, ProgramConfiguration& pcfg) { - memset(registerUsage, -1, sizeof(registerUsage)); - codePos = ((uint8_t*)randomx_program_prologue_first_load) - ((uint8_t*)randomx_program_prologue); code[codePos + 2] = 0xc0 + pcfg.readReg0; code[codePos + 5] = 0xc0 + pcfg.readReg1; @@ -280,13 +366,21 @@ namespace randomx { memcpy(code + codePos - 48, &pcfg.eMask, sizeof(pcfg.eMask)); memcpy(code + codePos, codeLoopLoad, loopLoadSize); codePos += loopLoadSize; - for (unsigned i = 0; i < prog.getSize(); ++i) { - Instruction& instr = prog(i); - instr.src %= RegistersCount; - instr.dst %= RegistersCount; - instructionOffsets[i] = codePos; - (this->*(engine[instr.opcode]))(instr, i); + + //mark all registers as used + uint64_t* r = (uint64_t*)registerUsage; + uint64_t k = codePos; + k |= k << 32; + for (unsigned j = 0; j < RegistersCount / 2; ++j) { + r[j] = k; } + + for (int i = 0, n = static_cast(RandomX_CurrentConfig.ProgramSize); i < n; ++i) { + Instruction instr = prog(i); + *((uint64_t*)&instr) &= (uint64_t(-1) - (0xFFFF << 8)) | ((RegistersCount - 1) << 8) | ((RegistersCount - 1) << 16); + (this->*(engine[instr.opcode]))(instr); + } + emit(REX_MOV_RR, code, codePos); emitByte(0xc0 + pcfg.readReg2, code, codePos); emit(REX_XOR_EAX, code, codePos); @@ -301,6 +395,22 @@ namespace randomx { emit(RandomX_CurrentConfig.codePrefetchScratchpadTweaked, prefetchScratchpadSize, code, codePos); memcpy(code + codePos, codeLoopStore, loopStoreSize); codePos += loopStoreSize; + + if (BranchesWithin32B) { + const uint32_t branch_begin = static_cast(codePos); + const uint32_t branch_end = static_cast(branch_begin + 9); + + // If the jump crosses or touches 32-byte boundary, align it + if ((branch_begin ^ branch_end) >= 32) { + uint32_t alignment_size = 32 - (branch_begin & 31); + if (alignment_size > 8) { + emit(NOPX[alignment_size - 9], alignment_size - 8, code, codePos); + alignment_size = 8; + } + emit(NOPX[alignment_size - 1], alignment_size, code, codePos); + } + } + emit(SUB_EBX, code, codePos); emit(JNZ, code, codePos); emit32(prologueSize - codePos - 4, code, codePos); @@ -311,103 +421,104 @@ namespace randomx { void JitCompilerX86::generateSuperscalarCode(Instruction& instr, std::vector &reciprocalCache) { switch ((SuperscalarInstructionType)instr.opcode) { - case randomx::SuperscalarInstructionType::ISUB_R: - emit(REX_SUB_RR, code, codePos); - emitByte(0xc0 + 8 * instr.dst + instr.src, code, codePos); - break; - case randomx::SuperscalarInstructionType::IXOR_R: - emit(REX_XOR_RR, code, codePos); - emitByte(0xc0 + 8 * instr.dst + instr.src, code, codePos); - break; - case randomx::SuperscalarInstructionType::IADD_RS: - emit(REX_LEA, code, codePos); - emitByte(0x04 + 8 * instr.dst, code, codePos); - genSIB(instr.getModShift(), instr.src, instr.dst, code, codePos); - break; - case randomx::SuperscalarInstructionType::IMUL_R: - emit(REX_IMUL_RR, code, codePos); - emitByte(0xc0 + 8 * instr.dst + instr.src, code, codePos); - break; - case randomx::SuperscalarInstructionType::IROR_C: - emit(REX_ROT_I8, code, codePos); - emitByte(0xc8 + instr.dst, code, codePos); - emitByte(instr.getImm32() & 63, code, codePos); - break; - case randomx::SuperscalarInstructionType::IADD_C7: - emit(REX_81, code, codePos); - emitByte(0xc0 + instr.dst, code, codePos); - emit32(instr.getImm32(), code, codePos); - break; - case randomx::SuperscalarInstructionType::IXOR_C7: - emit(REX_XOR_RI, code, codePos); - emitByte(0xf0 + instr.dst, code, codePos); - emit32(instr.getImm32(), code, codePos); - break; - case randomx::SuperscalarInstructionType::IADD_C8: - emit(REX_81, code, codePos); - emitByte(0xc0 + instr.dst, code, codePos); - emit32(instr.getImm32(), code, codePos); + case randomx::SuperscalarInstructionType::ISUB_R: + emit(REX_SUB_RR, code, codePos); + emitByte(0xc0 + 8 * instr.dst + instr.src, code, codePos); + break; + case randomx::SuperscalarInstructionType::IXOR_R: + emit(REX_XOR_RR, code, codePos); + emitByte(0xc0 + 8 * instr.dst + instr.src, code, codePos); + break; + case randomx::SuperscalarInstructionType::IADD_RS: + emit(REX_LEA, code, codePos); + emitByte(0x04 + 8 * instr.dst, code, codePos); + genSIB(instr.getModShift(), instr.src, instr.dst, code, codePos); + break; + case randomx::SuperscalarInstructionType::IMUL_R: + emit(REX_IMUL_RR, code, codePos); + emitByte(0xc0 + 8 * instr.dst + instr.src, code, codePos); + break; + case randomx::SuperscalarInstructionType::IROR_C: + emit(REX_ROT_I8, code, codePos); + emitByte(0xc8 + instr.dst, code, codePos); + emitByte(instr.getImm32() & 63, code, codePos); + break; + case randomx::SuperscalarInstructionType::IADD_C7: + emit(REX_81, code, codePos); + emitByte(0xc0 + instr.dst, code, codePos); + emit32(instr.getImm32(), code, codePos); + break; + case randomx::SuperscalarInstructionType::IXOR_C7: + emit(REX_XOR_RI, code, codePos); + emitByte(0xf0 + instr.dst, code, codePos); + emit32(instr.getImm32(), code, codePos); + break; + case randomx::SuperscalarInstructionType::IADD_C8: + emit(REX_81, code, codePos); + emitByte(0xc0 + instr.dst, code, codePos); + emit32(instr.getImm32(), code, codePos); #ifdef RANDOMX_ALIGN - emit(NOP1, code, codePos); + emit(NOP1, code, codePos); #endif - break; - case randomx::SuperscalarInstructionType::IXOR_C8: - emit(REX_XOR_RI, code, codePos); - emitByte(0xf0 + instr.dst, code, codePos); - emit32(instr.getImm32(), code, codePos); + break; + case randomx::SuperscalarInstructionType::IXOR_C8: + emit(REX_XOR_RI, code, codePos); + emitByte(0xf0 + instr.dst, code, codePos); + emit32(instr.getImm32(), code, codePos); #ifdef RANDOMX_ALIGN - emit(NOP1, code, codePos); + emit(NOP1, code, codePos); #endif - break; - case randomx::SuperscalarInstructionType::IADD_C9: - emit(REX_81, code, codePos); - emitByte(0xc0 + instr.dst, code, codePos); - emit32(instr.getImm32(), code, codePos); + break; + case randomx::SuperscalarInstructionType::IADD_C9: + emit(REX_81, code, codePos); + emitByte(0xc0 + instr.dst, code, codePos); + emit32(instr.getImm32(), code, codePos); #ifdef RANDOMX_ALIGN - emit(NOP2, code, codePos); + emit(NOP2, code, codePos); #endif - break; - case randomx::SuperscalarInstructionType::IXOR_C9: - emit(REX_XOR_RI, code, codePos); - emitByte(0xf0 + instr.dst, code, codePos); - emit32(instr.getImm32(), code, codePos); + break; + case randomx::SuperscalarInstructionType::IXOR_C9: + emit(REX_XOR_RI, code, codePos); + emitByte(0xf0 + instr.dst, code, codePos); + emit32(instr.getImm32(), code, codePos); #ifdef RANDOMX_ALIGN - emit(NOP2, code, codePos); + emit(NOP2, code, codePos); #endif - break; - case randomx::SuperscalarInstructionType::IMULH_R: - emit(REX_MOV_RR64, code, codePos); - emitByte(0xc0 + instr.dst, code, codePos); - emit(REX_MUL_R, code, codePos); - emitByte(0xe0 + instr.src, code, codePos); - emit(REX_MOV_R64R, code, codePos); - emitByte(0xc2 + 8 * instr.dst, code, codePos); - break; - case randomx::SuperscalarInstructionType::ISMULH_R: - emit(REX_MOV_RR64, code, codePos); - emitByte(0xc0 + instr.dst, code, codePos); - emit(REX_MUL_R, code, codePos); - emitByte(0xe8 + instr.src, code, codePos); - emit(REX_MOV_R64R, code, codePos); - emitByte(0xc2 + 8 * instr.dst, code, codePos); - break; - case randomx::SuperscalarInstructionType::IMUL_RCP: - emit(MOV_RAX_I, code, codePos); - emit64(reciprocalCache[instr.getImm32()], code, codePos); - emit(REX_IMUL_RM, code, codePos); - emitByte(0xc0 + 8 * instr.dst, code, codePos); - break; - default: - UNREACHABLE; + break; + case randomx::SuperscalarInstructionType::IMULH_R: + emit(REX_MOV_RR64, code, codePos); + emitByte(0xc0 + instr.dst, code, codePos); + emit(REX_MUL_R, code, codePos); + emitByte(0xe0 + instr.src, code, codePos); + emit(REX_MOV_R64R, code, codePos); + emitByte(0xc2 + 8 * instr.dst, code, codePos); + break; + case randomx::SuperscalarInstructionType::ISMULH_R: + emit(REX_MOV_RR64, code, codePos); + emitByte(0xc0 + instr.dst, code, codePos); + emit(REX_MUL_R, code, codePos); + emitByte(0xe8 + instr.src, code, codePos); + emit(REX_MOV_R64R, code, codePos); + emitByte(0xc2 + 8 * instr.dst, code, codePos); + break; + case randomx::SuperscalarInstructionType::IMUL_RCP: + emit(MOV_RAX_I, code, codePos); + emit64(reciprocalCache[instr.getImm32()], code, codePos); + emit(REX_IMUL_RM, code, codePos); + emitByte(0xc0 + 8 * instr.dst, code, codePos); + break; + default: + UNREACHABLE; } } - void JitCompilerX86::genAddressReg(Instruction& instr, uint8_t* code, int& codePos, bool rax) { - emit(LEA_32, code, codePos); - emitByte(0x80 + instr.src + (rax ? 0 : 8), code, codePos); - if (instr.src == RegisterNeedsSib) { - emitByte(0x24, code, codePos); - } + template + FORCE_INLINE void JitCompilerX86::genAddressReg(const Instruction& instr, uint8_t* code, int& codePos) { + const uint32_t src = *((uint32_t*)&instr) & 0xFF0000; + + *(uint32_t*)(code + codePos) = (rax ? 0x24808d41 : 0x24888d41) + src; + codePos += (src == (RegisterNeedsSib << 16)) ? 4 : 3; + emit32(instr.getImm32(), code, codePos); if (rax) emitByte(AND_EAX_I, code, codePos); @@ -416,12 +527,14 @@ namespace randomx { emit32(instr.getModMem() ? ScratchpadL1Mask : ScratchpadL2Mask, code, codePos); } - void JitCompilerX86::genAddressRegDst(Instruction& instr, uint8_t* code, int& codePos) { - emit(LEA_32, code, codePos); - emitByte(0x80 + instr.dst, code, codePos); - if (instr.dst == RegisterNeedsSib) { - emitByte(0x24, code, codePos); - } + template void JitCompilerX86::genAddressReg(const Instruction& instr, uint8_t* code, int& codePos); + template void JitCompilerX86::genAddressReg(const Instruction& instr, uint8_t* code, int& codePos); + + FORCE_INLINE void JitCompilerX86::genAddressRegDst(const Instruction& instr, uint8_t* code, int& codePos) { + const uint32_t dst = static_cast(instr.dst) << 16; + *(uint32_t*)(code + codePos) = 0x24808d41 + dst; + codePos += (dst == (RegisterNeedsSib << 16)) ? 4 : 3; + emit32(instr.getImm32(), code, codePos); emitByte(AND_EAX_I, code, codePos); if (instr.getModCond() < StoreL3Condition) { @@ -432,7 +545,7 @@ namespace randomx { } } - void JitCompilerX86::genAddressImm(Instruction& instr, uint8_t* code, int& codePos) { + FORCE_INLINE void JitCompilerX86::genAddressImm(const Instruction& instr, uint8_t* code, int& codePos) { emit32(instr.getImm32() & ScratchpadL3Mask, code, codePos); } @@ -447,17 +560,18 @@ namespace randomx { 0x3c8d4f, }; - void JitCompilerX86::h_IADD_RS(Instruction& instr, int i) { + void JitCompilerX86::h_IADD_RS(const Instruction& instr) { int pos = codePos; uint8_t* const p = code + pos; - registerUsage[instr.dst] = i; - const uint32_t sib = (instr.getModShift() << 6) | (instr.src << 3) | instr.dst; *(uint32_t*)(p) = template_IADD_RS[instr.dst] | (sib << 24); *(uint32_t*)(p + 4) = instr.getImm32(); - codePos = pos + ((instr.dst == RegisterNeedsDisplacement) ? 8 : 4); + pos += ((instr.dst == RegisterNeedsDisplacement) ? 8 : 4); + + registerUsage[instr.dst] = pos; + codePos = pos; } static const uint32_t template_IADD_M[8] = { @@ -471,13 +585,12 @@ namespace randomx { 0x063c034c, }; - void JitCompilerX86::h_IADD_M(Instruction& instr, int i) { + void JitCompilerX86::h_IADD_M(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - registerUsage[instr.dst] = i; + if (instr.src != instr.dst) { - genAddressReg(instr, p, pos); + genAddressReg(instr, p, pos); emit32(template_IADD_M[instr.dst], p, pos); } else { @@ -486,6 +599,7 @@ namespace randomx { genAddressImm(instr, p, pos); } + registerUsage[instr.dst] = pos; codePos = pos; } @@ -493,11 +607,10 @@ namespace randomx { emitByte((scale << 6) | (index << 3) | base, code, codePos); } - void JitCompilerX86::h_ISUB_R(Instruction& instr, int i) { + void JitCompilerX86::h_ISUB_R(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - registerUsage[instr.dst] = i; + if (instr.src != instr.dst) { emit(REX_SUB_RR, p, pos); emitByte(0xc0 + 8 * instr.dst + instr.src, p, pos); @@ -508,16 +621,16 @@ namespace randomx { emit32(instr.getImm32(), p, pos); } + registerUsage[instr.dst] = pos; codePos = pos; } - void JitCompilerX86::h_ISUB_M(Instruction& instr, int i) { + void JitCompilerX86::h_ISUB_M(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - registerUsage[instr.dst] = i; + if (instr.src != instr.dst) { - genAddressReg(instr, p, pos); + genAddressReg(instr, p, pos); emit(REX_SUB_RM, p, pos); emitByte(0x04 + 8 * instr.dst, p, pos); emitByte(0x06, p, pos); @@ -528,14 +641,14 @@ namespace randomx { genAddressImm(instr, p, pos); } + registerUsage[instr.dst] = pos; codePos = pos; } - void JitCompilerX86::h_IMUL_R(Instruction& instr, int i) { + void JitCompilerX86::h_IMUL_R(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - registerUsage[instr.dst] = i; + if (instr.src != instr.dst) { emit(REX_IMUL_RR, p, pos); emitByte(0xc0 + 8 * instr.dst + instr.src, p, pos); @@ -546,16 +659,16 @@ namespace randomx { emit32(instr.getImm32(), p, pos); } + registerUsage[instr.dst] = pos; codePos = pos; } - void JitCompilerX86::h_IMUL_M(Instruction& instr, int i) { + void JitCompilerX86::h_IMUL_M(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - registerUsage[instr.dst] = i; + if (instr.src != instr.dst) { - genAddressReg(instr, p, pos); + genAddressReg(instr, p, pos); emit(REX_IMUL_RM, p, pos); emitByte(0x04 + 8 * instr.dst, p, pos); emitByte(0x06, p, pos); @@ -566,14 +679,14 @@ namespace randomx { genAddressImm(instr, p, pos); } + registerUsage[instr.dst] = pos; codePos = pos; } - void JitCompilerX86::h_IMULH_R(Instruction& instr, int i) { + void JitCompilerX86::h_IMULH_R(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - registerUsage[instr.dst] = i; emit(REX_MOV_RR64, p, pos); emitByte(0xc0 + instr.dst, p, pos); emit(REX_MUL_R, p, pos); @@ -581,16 +694,16 @@ namespace randomx { emit(REX_MOV_R64R, p, pos); emitByte(0xc2 + 8 * instr.dst, p, pos); + registerUsage[instr.dst] = pos; codePos = pos; } - void JitCompilerX86::h_IMULH_M(Instruction& instr, int i) { + void JitCompilerX86::h_IMULH_M(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - registerUsage[instr.dst] = i; + if (instr.src != instr.dst) { - genAddressReg(instr, p, pos, false); + genAddressReg(instr, p, pos); emit(REX_MOV_RR64, p, pos); emitByte(0xc0 + instr.dst, p, pos); emit(REX_MUL_MEM, p, pos); @@ -605,14 +718,14 @@ namespace randomx { emit(REX_MOV_R64R, p, pos); emitByte(0xc2 + 8 * instr.dst, p, pos); + registerUsage[instr.dst] = pos; codePos = pos; } - void JitCompilerX86::h_ISMULH_R(Instruction& instr, int i) { + void JitCompilerX86::h_ISMULH_R(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - registerUsage[instr.dst] = i; + emit(REX_MOV_RR64, p, pos); emitByte(0xc0 + instr.dst, p, pos); emit(REX_MUL_R, p, pos); @@ -620,16 +733,16 @@ namespace randomx { emit(REX_MOV_R64R, p, pos); emitByte(0xc2 + 8 * instr.dst, p, pos); + registerUsage[instr.dst] = pos; codePos = pos; } - void JitCompilerX86::h_ISMULH_M(Instruction& instr, int i) { + void JitCompilerX86::h_ISMULH_M(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - registerUsage[instr.dst] = i; + if (instr.src != instr.dst) { - genAddressReg(instr, p, pos, false); + genAddressReg(instr, p, pos); emit(REX_MOV_RR64, p, pos); emitByte(0xc0 + instr.dst, p, pos); emit(REX_IMUL_MEM, p, pos); @@ -644,41 +757,41 @@ namespace randomx { emit(REX_MOV_R64R, p, pos); emitByte(0xc2 + 8 * instr.dst, p, pos); + registerUsage[instr.dst] = pos; codePos = pos; } - void JitCompilerX86::h_IMUL_RCP(Instruction& instr, int i) { + void JitCompilerX86::h_IMUL_RCP(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - + uint64_t divisor = instr.getImm32(); if (!isZeroOrPowerOf2(divisor)) { - registerUsage[instr.dst] = i; emit(MOV_RAX_I, p, pos); emit64(randomx_reciprocal_fast(divisor), p, pos); emit(REX_IMUL_RM, p, pos); emitByte(0xc0 + 8 * instr.dst, p, pos); + registerUsage[instr.dst] = pos; } codePos = pos; } - void JitCompilerX86::h_INEG_R(Instruction& instr, int i) { + void JitCompilerX86::h_INEG_R(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - registerUsage[instr.dst] = i; + emit(REX_NEG, p, pos); emitByte(0xd8 + instr.dst, p, pos); + registerUsage[instr.dst] = pos; codePos = pos; } - void JitCompilerX86::h_IXOR_R(Instruction& instr, int i) { + void JitCompilerX86::h_IXOR_R(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - registerUsage[instr.dst] = i; + if (instr.src != instr.dst) { emit(REX_XOR_RR, p, pos); emitByte(0xc0 + 8 * instr.dst + instr.src, p, pos); @@ -689,16 +802,16 @@ namespace randomx { emit32(instr.getImm32(), p, pos); } + registerUsage[instr.dst] = pos; codePos = pos; } - void JitCompilerX86::h_IXOR_M(Instruction& instr, int i) { + void JitCompilerX86::h_IXOR_M(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - registerUsage[instr.dst] = i; + if (instr.src != instr.dst) { - genAddressReg(instr, p, pos); + genAddressReg(instr, p, pos); emit(REX_XOR_RM, p, pos); emitByte(0x04 + 8 * instr.dst, p, pos); emitByte(0x06, p, pos); @@ -709,14 +822,14 @@ namespace randomx { genAddressImm(instr, p, pos); } + registerUsage[instr.dst] = pos; codePos = pos; } - void JitCompilerX86::h_IROR_R(Instruction& instr, int i) { + void JitCompilerX86::h_IROR_R(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - registerUsage[instr.dst] = i; + if (instr.src != instr.dst) { emit(REX_MOV_RR, p, pos); emitByte(0xc8 + instr.src, p, pos); @@ -729,14 +842,14 @@ namespace randomx { emitByte(instr.getImm32() & 63, p, pos); } + registerUsage[instr.dst] = pos; codePos = pos; } - void JitCompilerX86::h_IROL_R(Instruction& instr, int i) { + void JitCompilerX86::h_IROL_R(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - registerUsage[instr.dst] = i; if (instr.src != instr.dst) { emit(REX_MOV_RR, p, pos); emitByte(0xc8 + instr.src, p, pos); @@ -749,27 +862,28 @@ namespace randomx { emitByte(instr.getImm32() & 63, p, pos); } + registerUsage[instr.dst] = pos; codePos = pos; } - void JitCompilerX86::h_ISWAP_R(Instruction& instr, int i) { + void JitCompilerX86::h_ISWAP_R(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - + if (instr.src != instr.dst) { - registerUsage[instr.dst] = i; - registerUsage[instr.src] = i; emit(REX_XCHG, p, pos); emitByte(0xc0 + instr.src + 8 * instr.dst, p, pos); + registerUsage[instr.dst] = pos; + registerUsage[instr.src] = pos; } codePos = pos; } - void JitCompilerX86::h_FSWAP_R(Instruction& instr, int i) { + void JitCompilerX86::h_FSWAP_R(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - + emit(SHUFPD, p, pos); emitByte(0xc0 + 9 * instr.dst, p, pos); emitByte(1, p, pos); @@ -777,105 +891,105 @@ namespace randomx { codePos = pos; } - void JitCompilerX86::h_FADD_R(Instruction& instr, int i) { + void JitCompilerX86::h_FADD_R(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - instr.dst %= RegisterCountFlt; - instr.src %= RegisterCountFlt; + const uint32_t dst = instr.dst % RegisterCountFlt; + const uint32_t src = instr.src % RegisterCountFlt; emit(REX_ADDPD, p, pos); - emitByte(0xc0 + instr.src + 8 * instr.dst, p, pos); + emitByte(0xc0 + src + 8 * dst, p, pos); codePos = pos; } - void JitCompilerX86::h_FADD_M(Instruction& instr, int i) { + void JitCompilerX86::h_FADD_M(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - instr.dst %= RegisterCountFlt; - genAddressReg(instr, p, pos); + + const uint32_t dst = instr.dst % RegisterCountFlt; + genAddressReg(instr, p, pos); emit(REX_CVTDQ2PD_XMM12, p, pos); emit(REX_ADDPD, p, pos); - emitByte(0xc4 + 8 * instr.dst, p, pos); + emitByte(0xc4 + 8 * dst, p, pos); codePos = pos; } - void JitCompilerX86::h_FSUB_R(Instruction& instr, int i) { + void JitCompilerX86::h_FSUB_R(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - instr.dst %= RegisterCountFlt; - instr.src %= RegisterCountFlt; + + const uint32_t dst = instr.dst % RegisterCountFlt; + const uint32_t src = instr.src % RegisterCountFlt; emit(REX_SUBPD, p, pos); - emitByte(0xc0 + instr.src + 8 * instr.dst, p, pos); + emitByte(0xc0 + src + 8 * dst, p, pos); codePos = pos; } - void JitCompilerX86::h_FSUB_M(Instruction& instr, int i) { + void JitCompilerX86::h_FSUB_M(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - instr.dst %= RegisterCountFlt; - genAddressReg(instr, p, pos); + + const uint32_t dst = instr.dst % RegisterCountFlt; + genAddressReg(instr, p, pos); emit(REX_CVTDQ2PD_XMM12, p, pos); emit(REX_SUBPD, p, pos); - emitByte(0xc4 + 8 * instr.dst, p, pos); + emitByte(0xc4 + 8 * dst, p, pos); codePos = pos; } - void JitCompilerX86::h_FSCAL_R(Instruction& instr, int i) { + void JitCompilerX86::h_FSCAL_R(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - instr.dst %= RegisterCountFlt; + + const uint32_t dst = instr.dst % RegisterCountFlt; emit(REX_XORPS, p, pos); - emitByte(0xc7 + 8 * instr.dst, p, pos); + emitByte(0xc7 + 8 * dst, p, pos); codePos = pos; } - void JitCompilerX86::h_FMUL_R(Instruction& instr, int i) { + void JitCompilerX86::h_FMUL_R(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - instr.dst %= RegisterCountFlt; - instr.src %= RegisterCountFlt; + + const uint32_t dst = instr.dst % RegisterCountFlt; + const uint32_t src = instr.src % RegisterCountFlt; emit(REX_MULPD, p, pos); - emitByte(0xe0 + instr.src + 8 * instr.dst, p, pos); + emitByte(0xe0 + src + 8 * dst, p, pos); codePos = pos; } - void JitCompilerX86::h_FDIV_M(Instruction& instr, int i) { + void JitCompilerX86::h_FDIV_M(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - instr.dst %= RegisterCountFlt; - genAddressReg(instr, p, pos); + + const uint32_t dst = instr.dst % RegisterCountFlt; + genAddressReg(instr, p, pos); emit(REX_CVTDQ2PD_XMM12, p, pos); emit(REX_ANDPS_XMM12, p, pos); emit(REX_DIVPD, p, pos); - emitByte(0xe4 + 8 * instr.dst, p, pos); + emitByte(0xe4 + 8 * dst, p, pos); codePos = pos; } - void JitCompilerX86::h_FSQRT_R(Instruction& instr, int i) { + void JitCompilerX86::h_FSQRT_R(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - instr.dst %= RegisterCountFlt; + + const uint32_t dst = instr.dst % RegisterCountFlt; emit(SQRTPD, p, pos); - emitByte(0xe4 + 9 * instr.dst, p, pos); + emitByte(0xe4 + 9 * dst, p, pos); codePos = pos; } - void JitCompilerX86::h_CFROUND(Instruction& instr, int i) { + void JitCompilerX86::h_CFROUND(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; @@ -891,27 +1005,46 @@ namespace randomx { codePos = pos; } - void JitCompilerX86::h_CBRANCH(Instruction& instr, int i) { + void JitCompilerX86::h_CBRANCH(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; - - int reg = instr.dst; - int target = registerUsage[reg] + 1; + + const int reg = instr.dst; + int32_t jmp_offset = registerUsage[reg] - (pos + 16); + + if (BranchesWithin32B) { + const uint32_t branch_begin = static_cast(pos + 7); + const uint32_t branch_end = static_cast(branch_begin + ((jmp_offset >= -128) ? 9 : 13)); + + // If the jump crosses or touches 32-byte boundary, align it + if ((branch_begin ^ branch_end) >= 32) { + const uint32_t alignment_size = 32 - (branch_begin & 31); + jmp_offset -= alignment_size; + emit(JMP_ALIGN_PREFIX[alignment_size], alignment_size, p, pos); + } + } + emit(REX_ADD_I, p, pos); emitByte(0xc0 + reg, p, pos); - int shift = instr.getModCond() + RandomX_CurrentConfig.JumpOffset; - uint32_t imm = instr.getImm32() | (1UL << shift); - if (RandomX_CurrentConfig.JumpOffset > 0 || shift > 0) - imm &= ~(1UL << (shift - 1)); + const int shift = instr.getModCond() + RandomX_CurrentConfig.JumpOffset; + const uint32_t imm = (instr.getImm32() | (1UL << shift)) & ~(1UL << (shift - 1)); emit32(imm, p, pos); emit(REX_TEST, p, pos); emitByte(0xc0 + reg, p, pos); emit32(RandomX_CurrentConfig.ConditionMask_Calculated << shift, p, pos); - emit(JZ, p, pos); - emit32(instructionOffsets[target] - (pos + 4), p, pos); + + if (jmp_offset >= -128) { + emitByte(JZ_SHORT, p, pos); + emitByte(jmp_offset, p, pos); + } + else { + emit(JZ, p, pos); + emit32(jmp_offset - 4, p, pos); + } + //mark all registers as used uint64_t* r = (uint64_t*) registerUsage; - uint64_t k = i; + uint64_t k = pos; k |= k << 32; for (unsigned j = 0; j < RegistersCount / 2; ++j) { r[j] = k; @@ -920,7 +1053,7 @@ namespace randomx { codePos = pos; } - void JitCompilerX86::h_ISTORE(Instruction& instr, int i) { + void JitCompilerX86::h_ISTORE(const Instruction& instr) { uint8_t* const p = code; int pos = codePos; @@ -932,10 +1065,10 @@ namespace randomx { codePos = pos; } - void JitCompilerX86::h_NOP(Instruction& instr, int i) { + void JitCompilerX86::h_NOP(const Instruction& instr) { emit(NOP1, code, codePos); } InstructionGeneratorX86 JitCompilerX86::engine[256] = {}; -} +} \ No newline at end of file diff --git a/src/crypto/randomx/jit_compiler_x86.hpp b/src/crypto/randomx/jit_compiler_x86.hpp index 30e7c281..64e162e2 100644 --- a/src/crypto/randomx/jit_compiler_x86.hpp +++ b/src/crypto/randomx/jit_compiler_x86.hpp @@ -36,12 +36,12 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. namespace randomx { class Program; - class ProgramConfiguration; + struct ProgramConfiguration; class SuperscalarProgram; class JitCompilerX86; class Instruction; - typedef void(JitCompilerX86::*InstructionGeneratorX86)(Instruction&, int); + typedef void(JitCompilerX86::*InstructionGeneratorX86)(const Instruction&); constexpr uint32_t CodeSize = 64 * 1024; @@ -66,16 +66,20 @@ namespace randomx { size_t getCodeSize(); static InstructionGeneratorX86 engine[256]; - int32_t instructionOffsets[512]; int registerUsage[RegistersCount]; + uint8_t* allocatedCode; uint8_t* code; int32_t codePos; + static bool BranchesWithin32B; + + static void applyTweaks(); void generateProgramPrologue(Program&, ProgramConfiguration&); void generateProgramEpilogue(Program&, ProgramConfiguration&); - static void genAddressReg(Instruction&, uint8_t* code, int& codePos, bool rax = true); - static void genAddressRegDst(Instruction&, uint8_t* code, int& codePos); - static void genAddressImm(Instruction&, uint8_t* code, int& codePos); + template + static void genAddressReg(const Instruction&, uint8_t* code, int& codePos); + static void genAddressRegDst(const Instruction&, uint8_t* code, int& codePos); + static void genAddressImm(const Instruction&, uint8_t* code, int& codePos); static void genSIB(int scale, int index, int base, uint8_t* code, int& codePos); void generateSuperscalarCode(Instruction &, std::vector &); @@ -105,36 +109,36 @@ namespace randomx { codePos += count; } - void h_IADD_RS(Instruction&, int); - void h_IADD_M(Instruction&, int); - void h_ISUB_R(Instruction&, int); - void h_ISUB_M(Instruction&, int); - void h_IMUL_R(Instruction&, int); - void h_IMUL_M(Instruction&, int); - void h_IMULH_R(Instruction&, int); - void h_IMULH_M(Instruction&, int); - void h_ISMULH_R(Instruction&, int); - void h_ISMULH_M(Instruction&, int); - void h_IMUL_RCP(Instruction&, int); - void h_INEG_R(Instruction&, int); - void h_IXOR_R(Instruction&, int); - void h_IXOR_M(Instruction&, int); - void h_IROR_R(Instruction&, int); - void h_IROL_R(Instruction&, int); - void h_ISWAP_R(Instruction&, int); - void h_FSWAP_R(Instruction&, int); - void h_FADD_R(Instruction&, int); - void h_FADD_M(Instruction&, int); - void h_FSUB_R(Instruction&, int); - void h_FSUB_M(Instruction&, int); - void h_FSCAL_R(Instruction&, int); - void h_FMUL_R(Instruction&, int); - void h_FDIV_M(Instruction&, int); - void h_FSQRT_R(Instruction&, int); - void h_CBRANCH(Instruction&, int); - void h_CFROUND(Instruction&, int); - void h_ISTORE(Instruction&, int); - void h_NOP(Instruction&, int); + void h_IADD_RS(const Instruction&); + void h_IADD_M(const Instruction&); + void h_ISUB_R(const Instruction&); + void h_ISUB_M(const Instruction&); + void h_IMUL_R(const Instruction&); + void h_IMUL_M(const Instruction&); + void h_IMULH_R(const Instruction&); + void h_IMULH_M(const Instruction&); + void h_ISMULH_R(const Instruction&); + void h_ISMULH_M(const Instruction&); + void h_IMUL_RCP(const Instruction&); + void h_INEG_R(const Instruction&); + void h_IXOR_R(const Instruction&); + void h_IXOR_M(const Instruction&); + void h_IROR_R(const Instruction&); + void h_IROL_R(const Instruction&); + void h_ISWAP_R(const Instruction&); + void h_FSWAP_R(const Instruction&); + void h_FADD_R(const Instruction&); + void h_FADD_M(const Instruction&); + void h_FSUB_R(const Instruction&); + void h_FSUB_M(const Instruction&); + void h_FSCAL_R(const Instruction&); + void h_FMUL_R(const Instruction&); + void h_FDIV_M(const Instruction&); + void h_FSQRT_R(const Instruction&); + void h_CBRANCH(const Instruction&); + void h_CFROUND(const Instruction&); + void h_ISTORE(const Instruction&); + void h_NOP(const Instruction&); }; -} +} \ No newline at end of file diff --git a/src/crypto/randomx/randomx.cpp b/src/crypto/randomx/randomx.cpp index d0e12b86..56df81f4 100644 --- a/src/crypto/randomx/randomx.cpp +++ b/src/crypto/randomx/randomx.cpp @@ -26,6 +26,7 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ +#include "crypto/randomx/common.hpp" #include "crypto/randomx/randomx.h" #include "crypto/randomx/dataset.hpp" #include "crypto/randomx/vm_interpreted.hpp" @@ -33,7 +34,13 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include "crypto/randomx/vm_compiled.hpp" #include "crypto/randomx/vm_compiled_light.hpp" #include "crypto/randomx/blake2/blake2.h" + +#if defined(_M_X64) || defined(__x86_64__) #include "crypto/randomx/jit_compiler_x86_static.hpp" +#elif defined(XMRIG_ARMv8) +#include "crypto/randomx/jit_compiler_a64_static.hpp" +#endif + #include RandomX_ConfigurationWownero::RandomX_ConfigurationWownero() @@ -85,6 +92,16 @@ RandomX_ConfigurationArqma::RandomX_ConfigurationArqma() ScratchpadL3_Size = 262144; } +RandomX_ConfigurationSafex::RandomX_ConfigurationSafex() +{ + ArgonIterations = 3; + ArgonSalt = "RandomSFX\x01"; + ProgramIterations = 2048; + ProgramCount = 8; + ScratchpadL2_Size = 262144; + ScratchpadL3_Size = 2097152; +} + RandomX_ConfigurationBase::RandomX_ConfigurationBase() : ArgonMemory(262144) , ArgonIterations(3) @@ -166,19 +183,10 @@ RandomX_ConfigurationBase::RandomX_ConfigurationBase() #endif } +static uint32_t Log2(size_t value) { return (value > 1) ? (Log2(value / 2) + 1) : 0; } + void RandomX_ConfigurationBase::Apply() { -#if defined(_M_X64) || defined(__x86_64__) - *(uint32_t*)(codeShhPrefetchTweaked + 3) = ArgonMemory * 16 - 1; - const uint32_t DatasetBaseMask = DatasetBaseSize - RANDOMX_DATASET_ITEM_SIZE; - *(uint32_t*)(codeReadDatasetTweaked + 7) = DatasetBaseMask; - *(uint32_t*)(codeReadDatasetTweaked + 23) = DatasetBaseMask; - *(uint32_t*)(codeReadDatasetLightSshInitTweaked + 59) = DatasetBaseMask; -#endif - - CacheLineAlignMask_Calculated = (DatasetBaseSize - 1) & ~(RANDOMX_DATASET_ITEM_SIZE - 1); - DatasetExtraItems_Calculated = DatasetExtraSize / RANDOMX_DATASET_ITEM_SIZE; - ScratchpadL1Mask_Calculated = (ScratchpadL1_Size / sizeof(uint64_t) - 1) * 8; ScratchpadL1Mask16_Calculated = (ScratchpadL1_Size / sizeof(uint64_t) / 2 - 1) * 16; ScratchpadL2Mask_Calculated = (ScratchpadL2_Size / sizeof(uint64_t) - 1) * 8; @@ -186,22 +194,40 @@ void RandomX_ConfigurationBase::Apply() ScratchpadL3Mask_Calculated = (((ScratchpadL3_Size / sizeof(uint64_t)) - 1) * 8); ScratchpadL3Mask64_Calculated = ((ScratchpadL3_Size / sizeof(uint64_t)) / 8 - 1) * 64; -#if defined(_M_X64) || defined(__x86_64__) - *(uint32_t*)(codePrefetchScratchpadTweaked + 4) = ScratchpadL3Mask64_Calculated; - *(uint32_t*)(codePrefetchScratchpadTweaked + 18) = ScratchpadL3Mask64_Calculated; -#endif + CacheLineAlignMask_Calculated = (DatasetBaseSize - 1) & ~(RANDOMX_DATASET_ITEM_SIZE - 1); + DatasetExtraItems_Calculated = DatasetExtraSize / RANDOMX_DATASET_ITEM_SIZE; ConditionMask_Calculated = (1 << JumpBits) - 1; - constexpr int CEIL_NULL = 0; - int k = 0; - #if defined(_M_X64) || defined(__x86_64__) + *(uint32_t*)(codeShhPrefetchTweaked + 3) = ArgonMemory * 16 - 1; + const uint32_t DatasetBaseMask = DatasetBaseSize - RANDOMX_DATASET_ITEM_SIZE; + *(uint32_t*)(codeReadDatasetTweaked + 7) = DatasetBaseMask; + *(uint32_t*)(codeReadDatasetTweaked + 23) = DatasetBaseMask; + *(uint32_t*)(codeReadDatasetLightSshInitTweaked + 59) = DatasetBaseMask; + + *(uint32_t*)(codePrefetchScratchpadTweaked + 4) = ScratchpadL3Mask64_Calculated; + *(uint32_t*)(codePrefetchScratchpadTweaked + 18) = ScratchpadL3Mask64_Calculated; + #define JIT_HANDLE(x, prev) randomx::JitCompilerX86::engine[k] = &randomx::JitCompilerX86::h_##x + +#elif defined(XMRIG_ARMv8) + + Log2_ScratchpadL1 = Log2(ScratchpadL1_Size); + Log2_ScratchpadL2 = Log2(ScratchpadL2_Size); + Log2_ScratchpadL3 = Log2(ScratchpadL3_Size); + Log2_DatasetBaseSize = Log2(DatasetBaseSize); + Log2_CacheSize = Log2((ArgonMemory * randomx::ArgonBlockSize) / randomx::CacheLineSize); + +#define JIT_HANDLE(x, prev) randomx::JitCompilerA64::engine[k] = &randomx::JitCompilerA64::h_##x + #else #define JIT_HANDLE(x, prev) #endif + constexpr int CEIL_NULL = 0; + int k = 0; + #define INST_HANDLE(x, prev) \ CEIL_##x = CEIL_##prev + RANDOMX_FREQ_##x; \ for (; k < CEIL_##x; ++k) { JIT_HANDLE(x, prev); } @@ -243,218 +269,236 @@ RandomX_ConfigurationMonero RandomX_MoneroConfig; RandomX_ConfigurationWownero RandomX_WowneroConfig; RandomX_ConfigurationLoki RandomX_LokiConfig; RandomX_ConfigurationArqma RandomX_ArqmaConfig; - +RandomX_ConfigurationSafex RandomX_SafexConfig; RandomX_ConfigurationBase RandomX_CurrentConfig; extern "C" { - randomx_cache *randomx_alloc_cache(randomx_flags flags) { - randomx_cache *cache = nullptr; +randomx_cache *randomx_alloc_cache(randomx_flags flags) { + randomx_cache *cache = nullptr; - try { - cache = new randomx_cache(); - switch (flags & (RANDOMX_FLAG_JIT | RANDOMX_FLAG_LARGE_PAGES)) { - case RANDOMX_FLAG_DEFAULT: - cache->dealloc = &randomx::deallocCache; - cache->jit = nullptr; - cache->initialize = &randomx::initCache; - cache->datasetInit = &randomx::initDataset; - cache->memory = (uint8_t*)randomx::DefaultAllocator::allocMemory(RANDOMX_CACHE_MAX_SIZE); - break; + try { + cache = new randomx_cache(); + switch (flags & (RANDOMX_FLAG_JIT | RANDOMX_FLAG_LARGE_PAGES)) { + case RANDOMX_FLAG_DEFAULT: + cache->dealloc = &randomx::deallocCache; + cache->jit = nullptr; + cache->initialize = &randomx::initCache; + cache->datasetInit = &randomx::initDataset; + cache->memory = (uint8_t*)randomx::DefaultAllocator::allocMemory(RANDOMX_CACHE_MAX_SIZE); + break; - case RANDOMX_FLAG_JIT: - cache->dealloc = &randomx::deallocCache; - cache->jit = new randomx::JitCompiler(); - cache->initialize = &randomx::initCacheCompile; - cache->datasetInit = cache->jit->getDatasetInitFunc(); - cache->memory = (uint8_t*)randomx::DefaultAllocator::allocMemory(RANDOMX_CACHE_MAX_SIZE); - break; + case RANDOMX_FLAG_JIT: + cache->dealloc = &randomx::deallocCache; + cache->jit = new randomx::JitCompiler(); + cache->initialize = &randomx::initCacheCompile; + cache->datasetInit = cache->jit->getDatasetInitFunc(); + cache->memory = (uint8_t*)randomx::DefaultAllocator::allocMemory(RANDOMX_CACHE_MAX_SIZE); + break; - case RANDOMX_FLAG_LARGE_PAGES: - cache->dealloc = &randomx::deallocCache; - cache->jit = nullptr; - cache->initialize = &randomx::initCache; - cache->datasetInit = &randomx::initDataset; - cache->memory = (uint8_t*)randomx::LargePageAllocator::allocMemory(RANDOMX_CACHE_MAX_SIZE); - break; + case RANDOMX_FLAG_LARGE_PAGES: + cache->dealloc = &randomx::deallocCache; + cache->jit = nullptr; + cache->initialize = &randomx::initCache; + cache->datasetInit = &randomx::initDataset; + cache->memory = (uint8_t*)randomx::LargePageAllocator::allocMemory(RANDOMX_CACHE_MAX_SIZE); + break; - case RANDOMX_FLAG_JIT | RANDOMX_FLAG_LARGE_PAGES: - cache->dealloc = &randomx::deallocCache; - cache->jit = new randomx::JitCompiler(); - cache->initialize = &randomx::initCacheCompile; - cache->datasetInit = cache->jit->getDatasetInitFunc(); - cache->memory = (uint8_t*)randomx::LargePageAllocator::allocMemory(RANDOMX_CACHE_MAX_SIZE); - break; + case RANDOMX_FLAG_JIT | RANDOMX_FLAG_LARGE_PAGES: + cache->dealloc = &randomx::deallocCache; + cache->jit = new randomx::JitCompiler(); + cache->initialize = &randomx::initCacheCompile; + cache->datasetInit = cache->jit->getDatasetInitFunc(); + cache->memory = (uint8_t*)randomx::LargePageAllocator::allocMemory(RANDOMX_CACHE_MAX_SIZE); + break; - default: - UNREACHABLE; - } + default: + UNREACHABLE; } - catch (std::exception &ex) { - if (cache != nullptr) { - randomx_release_cache(cache); - cache = nullptr; - } + } + catch (std::exception &ex) { + if (cache != nullptr) { + randomx_release_cache(cache); + cache = nullptr; } - - return cache; - } - - void randomx_init_cache(randomx_cache *cache, const void *key, size_t keySize) { - assert(cache != nullptr); - assert(keySize == 0 || key != nullptr); - cache->initialize(cache, key, keySize); - } - - void randomx_release_cache(randomx_cache* cache) { - assert(cache != nullptr); - cache->dealloc(cache); - delete cache; - } - - randomx_dataset *randomx_alloc_dataset(randomx_flags flags) { - randomx_dataset *dataset = nullptr; - - try { - dataset = new randomx_dataset(); - if (flags & RANDOMX_FLAG_LARGE_PAGES) { - dataset->dealloc = &randomx::deallocDataset; - dataset->memory = (uint8_t*)randomx::LargePageAllocator::allocMemory(RANDOMX_DATASET_MAX_SIZE); - } - else { - dataset->dealloc = &randomx::deallocDataset; - dataset->memory = (uint8_t*)randomx::DefaultAllocator::allocMemory(RANDOMX_DATASET_MAX_SIZE); - } - } - catch (std::exception &ex) { - if (dataset != nullptr) { - randomx_release_dataset(dataset); - dataset = nullptr; - } - } - - return dataset; - } - - #define DatasetItemCount ((RandomX_CurrentConfig.DatasetBaseSize + RandomX_CurrentConfig.DatasetExtraSize) / RANDOMX_DATASET_ITEM_SIZE) - - unsigned long randomx_dataset_item_count() { - return DatasetItemCount; - } - - void randomx_init_dataset(randomx_dataset *dataset, randomx_cache *cache, unsigned long startItem, unsigned long itemCount) { - assert(dataset != nullptr); - assert(cache != nullptr); - assert(startItem < DatasetItemCount && itemCount <= DatasetItemCount); - assert(startItem + itemCount <= DatasetItemCount); - cache->datasetInit(cache, dataset->memory + startItem * randomx::CacheLineSize, startItem, startItem + itemCount); - } - - void *randomx_get_dataset_memory(randomx_dataset *dataset) { - assert(dataset != nullptr); - return dataset->memory; - } - - void randomx_release_dataset(randomx_dataset *dataset) { - assert(dataset != nullptr); - dataset->dealloc(dataset); - delete dataset; - } - - randomx_vm *randomx_create_vm(randomx_flags flags, randomx_cache *cache, randomx_dataset *dataset, uint8_t *scratchpad) { - assert(cache != nullptr || (flags & RANDOMX_FLAG_FULL_MEM)); - assert(cache == nullptr || cache->isInitialized()); - assert(dataset != nullptr || !(flags & RANDOMX_FLAG_FULL_MEM)); - - randomx_vm *vm = nullptr; - - try { - switch (flags & (RANDOMX_FLAG_FULL_MEM | RANDOMX_FLAG_JIT | RANDOMX_FLAG_HARD_AES)) { - case RANDOMX_FLAG_DEFAULT: - vm = new randomx::InterpretedLightVmDefault(); - break; - - case RANDOMX_FLAG_FULL_MEM: - vm = new randomx::InterpretedVmDefault(); - break; - - case RANDOMX_FLAG_JIT: - vm = new randomx::CompiledLightVmDefault(); - break; - - case RANDOMX_FLAG_FULL_MEM | RANDOMX_FLAG_JIT: - vm = new randomx::CompiledVmDefault(); - break; - - case RANDOMX_FLAG_HARD_AES: - vm = new randomx::InterpretedLightVmHardAes(); - break; - - case RANDOMX_FLAG_FULL_MEM | RANDOMX_FLAG_HARD_AES: - vm = new randomx::InterpretedVmHardAes(); - break; - - case RANDOMX_FLAG_JIT | RANDOMX_FLAG_HARD_AES: - vm = new randomx::CompiledLightVmHardAes(); - break; - - case RANDOMX_FLAG_FULL_MEM | RANDOMX_FLAG_JIT | RANDOMX_FLAG_HARD_AES: - vm = new randomx::CompiledVmHardAes(); - break; - - default: - UNREACHABLE; - } - - if (cache != nullptr) { - vm->setCache(cache); - } - - if (dataset != nullptr) { - vm->setDataset(dataset); - } - - vm->setScratchpad(scratchpad); - } - catch (std::exception &ex) { - delete vm; - vm = nullptr; - } - - return vm; - } - - void randomx_vm_set_cache(randomx_vm *machine, randomx_cache* cache) { - assert(machine != nullptr); - assert(cache != nullptr && cache->isInitialized()); - machine->setCache(cache); - } - - void randomx_vm_set_dataset(randomx_vm *machine, randomx_dataset *dataset) { - assert(machine != nullptr); - assert(dataset != nullptr); - machine->setDataset(dataset); - } - - void randomx_destroy_vm(randomx_vm *machine) { - assert(machine != nullptr); - delete machine; - } - - void randomx_calculate_hash(randomx_vm *machine, const void *input, size_t inputSize, void *output) { - assert(machine != nullptr); - assert(inputSize == 0 || input != nullptr); - assert(output != nullptr); - alignas(16) uint64_t tempHash[8]; - rx_blake2b(tempHash, sizeof(tempHash), input, inputSize, nullptr, 0); - machine->initScratchpad(&tempHash); - machine->resetRoundingMode(); - for (uint32_t chain = 0; chain < RandomX_CurrentConfig.ProgramCount - 1; ++chain) { - machine->run(&tempHash); - rx_blake2b(tempHash, sizeof(tempHash), machine->getRegisterFile(), sizeof(randomx::RegisterFile), nullptr, 0); - } - machine->run(&tempHash); - machine->getFinalResult(output, RANDOMX_HASH_SIZE); } + return cache; } + +void randomx_init_cache(randomx_cache *cache, const void *key, size_t keySize) { + assert(cache != nullptr); + assert(keySize == 0 || key != nullptr); + cache->initialize(cache, key, keySize); +} + +void randomx_release_cache(randomx_cache* cache) { + assert(cache != nullptr); + cache->dealloc(cache); + delete cache; +} + +randomx_dataset *randomx_alloc_dataset(randomx_flags flags) { + randomx_dataset *dataset = nullptr; + + try { + dataset = new randomx_dataset(); + if (flags & RANDOMX_FLAG_LARGE_PAGES) { + dataset->dealloc = &randomx::deallocDataset; + dataset->memory = (uint8_t*)randomx::LargePageAllocator::allocMemory(RANDOMX_DATASET_MAX_SIZE); + } + else { + dataset->dealloc = &randomx::deallocDataset; + dataset->memory = (uint8_t*)randomx::DefaultAllocator::allocMemory(RANDOMX_DATASET_MAX_SIZE); + } + } + catch (std::exception &ex) { + if (dataset != nullptr) { + randomx_release_dataset(dataset); + dataset = nullptr; + } + } + + return dataset; +} + +#define DatasetItemCount ((RandomX_CurrentConfig.DatasetBaseSize + RandomX_CurrentConfig.DatasetExtraSize) / RANDOMX_DATASET_ITEM_SIZE) + +unsigned long randomx_dataset_item_count() { + return DatasetItemCount; +} + +void randomx_init_dataset(randomx_dataset *dataset, randomx_cache *cache, unsigned long startItem, unsigned long itemCount) { + assert(dataset != nullptr); + assert(cache != nullptr); + assert(startItem < DatasetItemCount && itemCount <= DatasetItemCount); + assert(startItem + itemCount <= DatasetItemCount); + cache->datasetInit(cache, dataset->memory + startItem * randomx::CacheLineSize, startItem, startItem + itemCount); +} + +void *randomx_get_dataset_memory(randomx_dataset *dataset) { + assert(dataset != nullptr); + return dataset->memory; +} + +void randomx_release_dataset(randomx_dataset *dataset) { + assert(dataset != nullptr); + dataset->dealloc(dataset); + delete dataset; +} + +randomx_vm *randomx_create_vm(randomx_flags flags, randomx_cache *cache, randomx_dataset *dataset, uint8_t *scratchpad) { + assert(cache != nullptr || (flags & RANDOMX_FLAG_FULL_MEM)); + assert(cache == nullptr || cache->isInitialized()); + assert(dataset != nullptr || !(flags & RANDOMX_FLAG_FULL_MEM)); + + randomx_vm *vm = nullptr; + + try { + switch (flags & (RANDOMX_FLAG_FULL_MEM | RANDOMX_FLAG_JIT | RANDOMX_FLAG_HARD_AES)) { + case RANDOMX_FLAG_DEFAULT: + vm = new randomx::InterpretedLightVmDefault(); + break; + + case RANDOMX_FLAG_FULL_MEM: + vm = new randomx::InterpretedVmDefault(); + break; + + case RANDOMX_FLAG_JIT: + vm = new randomx::CompiledLightVmDefault(); + break; + + case RANDOMX_FLAG_FULL_MEM | RANDOMX_FLAG_JIT: + vm = new randomx::CompiledVmDefault(); + break; + + case RANDOMX_FLAG_HARD_AES: + vm = new randomx::InterpretedLightVmHardAes(); + break; + + case RANDOMX_FLAG_FULL_MEM | RANDOMX_FLAG_HARD_AES: + vm = new randomx::InterpretedVmHardAes(); + break; + + case RANDOMX_FLAG_JIT | RANDOMX_FLAG_HARD_AES: + vm = new randomx::CompiledLightVmHardAes(); + break; + + case RANDOMX_FLAG_FULL_MEM | RANDOMX_FLAG_JIT | RANDOMX_FLAG_HARD_AES: + vm = new randomx::CompiledVmHardAes(); + break; + + default: + UNREACHABLE; + } + + if (cache != nullptr) { + vm->setCache(cache); + } + + if (dataset != nullptr) { + vm->setDataset(dataset); + } + + vm->setScratchpad(scratchpad); + } + catch (std::exception &ex) { + delete vm; + vm = nullptr; + } + + return vm; +} + +void randomx_vm_set_cache(randomx_vm *machine, randomx_cache* cache) { + assert(machine != nullptr); + assert(cache != nullptr && cache->isInitialized()); + machine->setCache(cache); +} + +void randomx_vm_set_dataset(randomx_vm *machine, randomx_dataset *dataset) { + assert(machine != nullptr); + assert(dataset != nullptr); + machine->setDataset(dataset); +} + +void randomx_destroy_vm(randomx_vm *machine) { + assert(machine != nullptr); + delete machine; +} + +void randomx_calculate_hash(randomx_vm *machine, const void *input, size_t inputSize, void *output) { + assert(machine != nullptr); + assert(inputSize == 0 || input != nullptr); + assert(output != nullptr); + alignas(16) uint64_t tempHash[8]; + rx_blake2b(tempHash, sizeof(tempHash), input, inputSize, nullptr, 0); + machine->initScratchpad(&tempHash); + machine->resetRoundingMode(); + for (uint32_t chain = 0; chain < RandomX_CurrentConfig.ProgramCount - 1; ++chain) { + machine->run(&tempHash); + rx_blake2b(tempHash, sizeof(tempHash), machine->getRegisterFile(), sizeof(randomx::RegisterFile), nullptr, 0); + } + machine->run(&tempHash); + machine->getFinalResult(output, RANDOMX_HASH_SIZE); +} + +void randomx_calculate_hash_first(randomx_vm* machine, uint64_t (&tempHash)[8], const void* input, size_t inputSize) { + rx_blake2b(tempHash, sizeof(tempHash), input, inputSize, nullptr, 0); + machine->initScratchpad(tempHash); +} + +void randomx_calculate_hash_next(randomx_vm* machine, uint64_t (&tempHash)[8], const void* nextInput, size_t nextInputSize, void* output) { + machine->resetRoundingMode(); + for (uint32_t chain = 0; chain < RandomX_CurrentConfig.ProgramCount - 1; ++chain) { + machine->run(&tempHash); + rx_blake2b(tempHash, sizeof(tempHash), machine->getRegisterFile(), sizeof(randomx::RegisterFile), nullptr, 0); + } + machine->run(&tempHash); + + // Finish current hash and fill the scratchpad for the next hash at the same time + rx_blake2b(tempHash, sizeof(tempHash), nextInput, nextInputSize, nullptr, 0); + machine->hashAndFill(output, RANDOMX_HASH_SIZE, tempHash); +} + +} \ No newline at end of file diff --git a/src/crypto/randomx/randomx.h b/src/crypto/randomx/randomx.h index 48203ae7..9655d35e 100644 --- a/src/crypto/randomx/randomx.h +++ b/src/crypto/randomx/randomx.h @@ -29,8 +29,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #ifndef RANDOMX_H #define RANDOMX_H -#include -#include +#include +#include #include #include "crypto/randomx/intrin_portable.h" @@ -41,17 +41,20 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define RANDOMX_EXPORT #endif -typedef enum { - RANDOMX_FLAG_DEFAULT = 0, - RANDOMX_FLAG_LARGE_PAGES = 1, - RANDOMX_FLAG_HARD_AES = 2, - RANDOMX_FLAG_FULL_MEM = 4, - RANDOMX_FLAG_JIT = 8, -} randomx_flags; -typedef struct randomx_dataset randomx_dataset; -typedef struct randomx_cache randomx_cache; -typedef struct randomx_vm randomx_vm; +enum randomx_flags { + RANDOMX_FLAG_DEFAULT = 0, + RANDOMX_FLAG_LARGE_PAGES = 1, + RANDOMX_FLAG_HARD_AES = 2, + RANDOMX_FLAG_FULL_MEM = 4, + RANDOMX_FLAG_JIT = 8, +}; + + +struct randomx_dataset; +struct randomx_cache; +class randomx_vm; + struct RandomX_ConfigurationBase { @@ -130,6 +133,14 @@ struct RandomX_ConfigurationBase uint32_t ConditionMask_Calculated; +#if defined(XMRIG_ARMv8) + uint32_t Log2_ScratchpadL1; + uint32_t Log2_ScratchpadL2; + uint32_t Log2_ScratchpadL3; + uint32_t Log2_DatasetBaseSize; + uint32_t Log2_CacheSize; +#endif + int CEIL_IADD_RS; int CEIL_IADD_M; int CEIL_ISUB_R; @@ -166,11 +177,13 @@ struct RandomX_ConfigurationMonero : public RandomX_ConfigurationBase {}; struct RandomX_ConfigurationWownero : public RandomX_ConfigurationBase { RandomX_ConfigurationWownero(); }; struct RandomX_ConfigurationLoki : public RandomX_ConfigurationBase { RandomX_ConfigurationLoki(); }; struct RandomX_ConfigurationArqma : public RandomX_ConfigurationBase { RandomX_ConfigurationArqma(); }; +struct RandomX_ConfigurationSafex : public RandomX_ConfigurationBase { RandomX_ConfigurationSafex(); }; extern RandomX_ConfigurationMonero RandomX_MoneroConfig; extern RandomX_ConfigurationWownero RandomX_WowneroConfig; extern RandomX_ConfigurationLoki RandomX_LokiConfig; extern RandomX_ConfigurationArqma RandomX_ArqmaConfig; +extern RandomX_ConfigurationSafex RandomX_SafexConfig; extern RandomX_ConfigurationBase RandomX_CurrentConfig; @@ -327,8 +340,11 @@ RANDOMX_EXPORT void randomx_destroy_vm(randomx_vm *machine); */ RANDOMX_EXPORT void randomx_calculate_hash(randomx_vm *machine, const void *input, size_t inputSize, void *output); +RANDOMX_EXPORT void randomx_calculate_hash_first(randomx_vm* machine, uint64_t (&tempHash)[8], const void* input, size_t inputSize); +RANDOMX_EXPORT void randomx_calculate_hash_next(randomx_vm* machine, uint64_t (&tempHash)[8], const void* nextInput, size_t nextInputSize, void* output); + #if defined(__cplusplus) } #endif -#endif +#endif \ No newline at end of file diff --git a/src/crypto/randomx/virtual_machine.cpp b/src/crypto/randomx/virtual_machine.cpp index 2913c7e5..8e937902 100644 --- a/src/crypto/randomx/virtual_machine.cpp +++ b/src/crypto/randomx/virtual_machine.cpp @@ -111,7 +111,13 @@ namespace randomx { template void VmBase::getFinalResult(void* out, size_t outSize) { hashAes1Rx4(scratchpad, ScratchpadSize, ®.a); - rx_blake2b(out, outSize, ®, sizeof(RegisterFile), nullptr, 0); + rx_blake2b(out, outSize, ®, sizeof(RegisterFile), nullptr, 0); + } + + template + void VmBase::hashAndFill(void* out, size_t outSize, uint64_t (&fill_state)[8]) { + hashAndFillAes1Rx4(scratchpad, ScratchpadSize, ®.a, fill_state); + rx_blake2b(out, outSize, ®, sizeof(RegisterFile), nullptr, 0); } template @@ -126,4 +132,4 @@ namespace randomx { template class VmBase; template class VmBase; -} +} \ No newline at end of file diff --git a/src/crypto/randomx/virtual_machine.hpp b/src/crypto/randomx/virtual_machine.hpp index 2dc89bb5..5ef72aa6 100644 --- a/src/crypto/randomx/virtual_machine.hpp +++ b/src/crypto/randomx/virtual_machine.hpp @@ -39,6 +39,7 @@ public: virtual ~randomx_vm() = 0; virtual void setScratchpad(uint8_t *scratchpad) = 0; virtual void getFinalResult(void* out, size_t outSize) = 0; + virtual void hashAndFill(void* out, size_t outSize, uint64_t (&fill_state)[8]) = 0; virtual void setDataset(randomx_dataset* dataset) { } virtual void setCache(randomx_cache* cache) { } virtual void initScratchpad(void* seed) = 0; @@ -64,7 +65,7 @@ protected: alignas(64) randomx::RegisterFile reg; alignas(16) randomx::ProgramConfiguration config; randomx::MemoryRegisters mem; - uint8_t* scratchpad; + uint8_t* scratchpad = nullptr; union { randomx_cache* cachePtr = nullptr; randomx_dataset* datasetPtr; @@ -82,9 +83,10 @@ namespace randomx { void setScratchpad(uint8_t *scratchpad) override; void initScratchpad(void* seed) override; void getFinalResult(void* out, size_t outSize) override; + void hashAndFill(void* out, size_t outSize, uint64_t (&fill_state)[8]) override; protected: void generateProgram(void* seed); }; -} +} \ No newline at end of file diff --git a/src/crypto/rx/RxAlgo.cpp b/src/crypto/rx/RxAlgo.cpp index 3cf97c55..1a36a532 100644 --- a/src/crypto/rx/RxAlgo.cpp +++ b/src/crypto/rx/RxAlgo.cpp @@ -49,6 +49,9 @@ const RandomX_ConfigurationBase *xmrig::RxAlgo::base(Algorithm::Id algorithm) case Algorithm::RX_ARQ: return &RandomX_ArqmaConfig; + case Algorithm::RX_SFX: + return &RandomX_SafexConfig; + default: break; } diff --git a/src/version.h b/src/version.h index c9f5a966..2191636e 100644 --- a/src/version.h +++ b/src/version.h @@ -28,7 +28,7 @@ #define APP_ID "XMRigCC" #define APP_NAME "XMRigCC" #define APP_DESC "XMRigCC CPU miner" -#define APP_VERSION "2.2.0" +#define APP_VERSION "2.2.1" #define APP_DOMAIN "" #define APP_SITE "https://github.com/BenDr0id/xmrigCC/" #define APP_COPYRIGHT "Copyright (C) 2017- XMRigCC" @@ -36,7 +36,7 @@ #define APP_VER_MAJOR 2 #define APP_VER_MINOR 2 -#define APP_VER_PATCH 0 +#define APP_VER_PATCH 1 #ifndef NDEBUG #define BUILD_TYPE "DEBUG"