From 4e37bccc1aea1f3b941df90b0c7e094638f700f7 Mon Sep 17 00:00:00 2001 From: fancy Date: Wed, 13 May 2020 23:43:07 +0800 Subject: [PATCH] add basic unix socket control --- CMakeLists.txt | 20 ++++++ common.h | 34 ++++++++++ main.cpp | 161 ++++++++++++++++++++++++++++++++++++++++++++++ socket_client.cpp | 80 +++++++++++++++++++++++ socket_server.h | 90 ++++++++++++++++++++++++++ 5 files changed, 385 insertions(+) create mode 100644 common.h create mode 100644 main.cpp create mode 100644 socket_client.cpp create mode 100644 socket_server.h diff --git a/CMakeLists.txt b/CMakeLists.txt index f239dae..38d220c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,8 +1,26 @@ cmake_minimum_required(VERSION 3.10) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_BUILD_TYPE DEBUG) +add_compile_options(-Wno-return-type) project(cgproxy VERSION 3.7) + + +find_package(Threads REQUIRED) +find_package(libconfig++ REQUIRED) +find_package(nlohmann_json REQUIRED) +include_directories(${PROJECT_SOURCE_DIR}) + add_executable(cgattach cgattach.cpp) +add_executable(main main.cpp) +target_link_libraries(main Threads::Threads ${LIBCONFIG++_LIBRARIES} nlohmann_json::nlohmann_json) + +add_executable(client socket_client.cpp) +target_link_libraries(client nlohmann_json::nlohmann_json) + + install(TARGETS cgattach DESTINATION /usr/bin PERMISSIONS OWNER_READ OWNER_WRITE OWNER_EXECUTE GROUP_READ GROUP_EXECUTE WORLD_READ WORLD_EXECUTE SETUID) install(FILES cgproxy.sh DESTINATION /usr/bin @@ -26,6 +44,8 @@ install(FILES readme.md DESTINATION /share/doc/cgproxy/) + + ## package for deb and rpm set(CPACK_GENERATOR "DEB;RPM") set(CPACK_PACKAGE_NAME "cgproxy") diff --git a/common.h b/common.h new file mode 100644 index 0000000..7089b33 --- /dev/null +++ b/common.h @@ -0,0 +1,34 @@ +#ifndef COMMON_H +#define COMMON_H + +#define SOCKET_PATH "/tmp/unix_socket" +#define LISTEN_BACKLOG 5 +#define DEFAULT_CONFIG_FILE "/etc/cgproxy.conf" + +#define MSG_TYPE_JSON 1 +#define MSG_TYPE_CONFIG_PATH 2 +#define MSG_TYPE_PROXY_PID 3 +#define MSG_TYPE_NOPROXY_PID 4 + +#define UNKNOWN_ERROR -99 +#define CONN_ERROR -1 +#define MSG_ERROR 1 +#define PARSE_ERROR 2 +#define PARAM_ERROR 3 +#define APPLY_ERROR 4 + +#include +#include +#include +using namespace std; +template string to_str(T... args) { + stringstream ss; + ss.clear(); + (ss << ... << args) << endl; + return ss.str(); +} + +#define error(...) {fprintf(stderr, __VA_ARGS__);fprintf(stderr, "\n");} +#define debug(...) {fprintf(stdout, __VA_ARGS__);fprintf(stdout, "\n");} + +#endif \ No newline at end of file diff --git a/main.cpp b/main.cpp new file mode 100644 index 0000000..a02846a --- /dev/null +++ b/main.cpp @@ -0,0 +1,161 @@ +#include "common.h" +#include "socket_server.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +using json = nlohmann::json; + +struct Config { + string cgroup_proxy = "/proxy.slice"; + string cgroup_noproxy = "/noproxy.slice"; + bool enable_gateway = false; + int port = 12345; + bool enable_dns = true; + bool enable_tcp = true; + bool enable_udp = true; + bool enable_ipv4 = true; + bool enable_ipv6 = true; + + void toEnv() { + #define env(v) setenv(#v, to_str(v).c_str(), 1) + env(cgroup_proxy); + env(cgroup_noproxy); + env(enable_gateway); + env(port); + env(enable_dns); + env(enable_tcp); + env(enable_udp); + env(enable_ipv4); + env(enable_ipv6); + #undef env + } + + int safeLoadFromFile(const string path){ + Config tmp=*this; + int flag=tmp.loadFromFile(path); + if (flag!=0) return flag; + if (tmp.isValid()){ + loadFromFile(path); + return 0; + }else{ + return PARAM_ERROR; + } + } + + int safeLoadFromJson(const json& j){ + Config tmp=*this; + int flag=tmp.loadFromJson(j); + if (flag!=0) return flag; + if (tmp.isValid()){ + loadFromJson(j); + return 0; + }else{ + return PARAM_ERROR; + } + } + + private: + int loadFromFile(const string f) { + debug("loading config: %s", f.c_str()); + libconfig::Config config_f; + try { config_f.readFile(f.c_str()); } catch (exception &e) { return PARSE_ERROR; } + #define assign(v, t) if (config_f.exists(#v)) {v = (t)config_f.lookup(#v);} + assign(cgroup_proxy, string); + assign(cgroup_noproxy, string); + assign(enable_gateway, bool); + assign(port, int); + assign(enable_dns, bool); + assign(enable_tcp, bool); + assign(enable_udp, bool); + assign(enable_ipv4, bool); + assign(enable_ipv6, bool); + #undef assign + return 0; + } + + int loadFromJson(const json &j) { + #define get_to(v) try {j.at(#v).get_to(v); } catch (exception& e) {} + get_to(cgroup_proxy); + get_to(cgroup_noproxy); + get_to(enable_gateway); + get_to(port); + get_to(enable_dns); + get_to(enable_tcp); + get_to(enable_udp); + get_to(enable_ipv4); + get_to(enable_ipv6); + #undef get_to + return 0; + } + + bool isValid(){ + // TODO + return true; + } +}; + +SocketControl sc; +thread_arg arg_t; +Config config_tproxy; +pthread_t socket_thread_id = -1; + +int applyConfig(Config *c) { + system("sh /usr/share/cgproxy/scripts/cgroup-tproxy.sh stop"); + c->toEnv(); + system("sh /usr/share/cgproxy/scripts/cgroup-tproxy.sh"); + return 0; +} + +int handle_msg(char *msg) { + debug("received msg: %s", msg); + json j; + try{ j = json::parse(msg); }catch(exception& e){debug("msg paser error");return MSG_ERROR;} + int type = -1, status; + try { + type = j.at("type").get(); + if (type == MSG_TYPE_JSON) { // json data + status=config_tproxy.safeLoadFromJson(j.at("data")); + } else if (type == MSG_TYPE_CONFIG_PATH) { // config file + status=config_tproxy.safeLoadFromFile(j.at("data").get()); + } + } catch (out_of_range &e) { + return MSG_ERROR; + } + if (status==0){ + return applyConfig(&config_tproxy); + } + return status; +} + +pthread_t startSocketListeningThread() { + arg_t.sc = ≻ + arg_t.handle_msg = &handle_msg; + pthread_t thread_id; + int status = + pthread_create(&thread_id, NULL, &SocketControl::startThread, &arg_t); + if (status != 0) + error("socket thread create failed"); + return thread_id; +} + +int main() { + bool enable_socket = true; + string config_path = DEFAULT_CONFIG_FILE; + config_tproxy.safeLoadFromFile(config_path); + applyConfig(&config_tproxy); + if (enable_socket) { + socket_thread_id = startSocketListeningThread(); + pthread_join(socket_thread_id, NULL); + } + return 0; +} + +// TODO handle attch pid \ No newline at end of file diff --git a/socket_client.cpp b/socket_client.cpp new file mode 100644 index 0000000..7301ba9 --- /dev/null +++ b/socket_client.cpp @@ -0,0 +1,80 @@ +#include "common.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +using json = nlohmann::json; + +#define return_if_error(flag, msg) \ + if (flag == -1) { \ + perror(msg); \ + status = CONN_ERROR; \ + return; \ + } + +void send(char *msg, int &status) { + debug("send msg: %s", msg); + status = UNKNOWN_ERROR; + + int flag; + int sfd = socket(AF_UNIX, SOCK_STREAM, 0); + + struct sockaddr_un unix_socket; + memset(&unix_socket, '\0', sizeof(struct sockaddr_un)); + unix_socket.sun_family = AF_UNIX; + strncpy(unix_socket.sun_path, SOCKET_PATH, sizeof(unix_socket.sun_path) - 1); + + flag = + connect(sfd, (struct sockaddr *)&unix_socket, sizeof(struct sockaddr_un)); + return_if_error(flag, "connect"); + + int msg_len = strlen(msg); + flag = write(sfd, &msg_len, sizeof(int)); + return_if_error(flag, "write length"); + flag = write(sfd, msg, msg_len * sizeof(char)); + return_if_error(flag, "write msg"); + + flag = read(sfd, &status, sizeof(int)); + return_if_error(flag, "read return value"); + + close(sfd); +} + +void send(const json &j, int &status) { + string msg = j.dump(); + int msg_len = msg.length(); + char buff[msg_len]; + msg.copy(buff, msg_len, 0); + buff[msg_len] = '\0'; + send(buff, status); + debug("return status: %d", status); +} + +int test_json() { + json j; + j["type"] = MSG_TYPE_JSON; + j["data"]["cgroup_proxy"] = "/"; + j["data"]["enable_dns"] = false; + int status; + send(j, status); +} + +void test_file() { + json j; + j["type"] = MSG_TYPE_CONFIG_PATH; + j["data"] = "/etc/cgproxy.conf"; + int status; + send(j, status); +} + +int main() { + test_file(); + test_json(); +} \ No newline at end of file diff --git a/socket_server.h b/socket_server.h new file mode 100644 index 0000000..2980251 --- /dev/null +++ b/socket_server.h @@ -0,0 +1,90 @@ +#ifndef SOCKET_SERVER_H +#define SOCKET_SERVER_H + +#include "common.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +using namespace std; + +#define SOCKET_PATH "/tmp/unix_socket" +#define LISTEN_BACKLOG 5 + +class SocketControl; +struct thread_arg; + +#define continue_if_error(flag, msg) \ + if (flag == -1) { \ + perror(msg); \ + continue; \ + } + +struct thread_arg { + SocketControl *sc; + function handle_msg; +}; + +class SocketControl { +public: + int sfd = -1, cfd = -1, flag = -1; + struct sockaddr_un unix_socket; + + void socketListening(function callback) { + debug("starting socket listening"); + sfd = socket(AF_UNIX, SOCK_STREAM, 0); + + unlink(SOCKET_PATH); + memset(&unix_socket, '\0', sizeof(struct sockaddr_un)); + unix_socket.sun_family = AF_UNIX; + strncpy(unix_socket.sun_path, SOCKET_PATH, + sizeof(unix_socket.sun_path) - 1); + + bind(sfd, (struct sockaddr *)&unix_socket, sizeof(struct sockaddr_un)); + + listen(sfd, LISTEN_BACKLOG); + chmod(SOCKET_PATH,S_IRWXU|S_IRWXG|S_IRWXO); + + while (true) { + close(cfd); + cfd = accept(sfd, NULL, NULL); + continue_if_error(cfd, "accept"); + debug("accept connection: %d", cfd); + + // read length + int msg_len; + flag = read(cfd, &msg_len, sizeof(int)); + continue_if_error(flag, "read length"); + // read msg + char msg[msg_len]; + flag = read(cfd, msg, msg_len * sizeof(char)); + continue_if_error(flag, "read msg"); + msg[msg_len]='\0'; + // handle msg + int status = callback(msg); + // send back flag + flag = write(cfd, &status, sizeof(int)); + continue_if_error(flag, "write back"); + } + } + + ~SocketControl() { + close(sfd); + close(cfd); + unlink(SOCKET_PATH); + } + + static void *startThread(void *arg) { + thread_arg *p = (thread_arg *)arg; + p->sc->socketListening(p->handle_msg); + } +}; + +#endif \ No newline at end of file