diff --git a/CMakeLists.txt b/CMakeLists.txt index 79abee5..7b84937 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,23 +3,21 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) # set(CMAKE_BUILD_TYPE DEBUG) set(CMAKE_BUILD_TYPE RELEASE) -add_compile_options(-Wno-return-type) project(cgproxy VERSION 3.7) find_package(Threads REQUIRED) -find_package(libconfig++ REQUIRED) find_package(nlohmann_json REQUIRED) include_directories(${PROJECT_SOURCE_DIR}) add_executable(cgattach cgattach.cpp) add_executable(main main.cpp) -target_link_libraries(main Threads::Threads ${LIBCONFIG++_LIBRARIES} nlohmann_json::nlohmann_json) +target_link_libraries(main Threads::Threads nlohmann_json::nlohmann_json) -add_executable(client socket_client.cpp) -target_link_libraries(client nlohmann_json::nlohmann_json) +add_executable(client_test socket_client_test.cpp) +target_link_libraries(client_test nlohmann_json::nlohmann_json) install(TARGETS cgattach DESTINATION /usr/bin @@ -30,14 +28,11 @@ install(FILES cgproxy.sh DESTINATION /usr/bin install(FILES cgnoproxy.sh DESTINATION /usr/bin RENAME cgnoproxy PERMISSIONS OWNER_READ OWNER_WRITE OWNER_EXECUTE GROUP_READ GROUP_EXECUTE WORLD_READ WORLD_EXECUTE) -# install(FILES run_in_cgroup.sh DESTINATION /usr/bin -# RENAME run_in_cgroup -# PERMISSIONS OWNER_READ OWNER_WRITE OWNER_EXECUTE GROUP_READ GROUP_EXECUTE WORLD_READ WORLD_EXECUTE) install(FILES cgproxy.service DESTINATION /usr/lib/systemd/system/) -install(FILES cgproxy.conf - DESTINATION /etc/) +install(FILES cgproxy.json + DESTINATION /etc/cgproxy/) install(FILES cgroup-tproxy.sh DESTINATION /usr/share/cgproxy/scripts/) diff --git a/cgattach.cpp b/cgattach.cpp index 7db6f98..ee5844e 100644 --- a/cgattach.cpp +++ b/cgattach.cpp @@ -1,98 +1,23 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include "cgroup_attach.h" using namespace std; void print_usage() { fprintf(stdout, "usage: cgattach \n"); } -bool exist(string path) { - struct stat st; - if (stat(path.c_str(), &st) != -1) { - return S_ISDIR(st.st_mode); - } - return false; -} - -bool validate(string pid, string cgroup) { - bool pid_v = regex_match(pid, regex("^[0-9]+$")); - bool cg_v = regex_match(cgroup, regex("^\\/[a-zA-Z0-9\\-_./@]*$")); - if (pid_v && cg_v) - return true; - - fprintf(stderr, "paramater validate error\n"); - print_usage(); - exit(EXIT_FAILURE); -} - -string get_cgroup2_mount_point(){ - char cgroup2_mount_point[100]=""; - FILE* fp = popen("findmnt -t cgroup2 -n -o TARGET", "r"); - int count=fscanf(fp,"%s",&cgroup2_mount_point); - fclose(fp); - if (count=0){ - fprintf(stderr, "cgroup2 not supported\n"); - exit(EXIT_FAILURE); - } - return cgroup2_mount_point; -} - int main(int argc, char *argv[]) { - setuid(0); - setgid(0); - if (getuid() != 0 || getgid() != 0) { - fprintf(stderr, "cgattach need suid sticky bit or run with root\n"); + int flag=setuid(0); + if (flag!=0) { + perror("cgattach setuid"); exit(EXIT_FAILURE); } if (argc != 3) { - fprintf(stderr, "only need 2 paramaters\n"); + error("only need 2 paramaters"); print_usage(); exit(EXIT_FAILURE); } string pid = string(argv[1]); string cgroup_target = string(argv[2]); - validate(pid, cgroup_target); - // string cgroup_mount_point = "/sys/fs/cgroup"; - string cgroup_mount_point = get_cgroup2_mount_point(); - string cgroup_target_path = cgroup_mount_point + cgroup_target; - string cgroup_target_procs = cgroup_target_path + "/cgroup.procs"; - // check if exist, we will create it if not exist - if (!exist(cgroup_target_path)) { - if (mkdir(cgroup_target_path.c_str(), - S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH) == 0) { - fprintf(stdout, "created cgroup %s success\n", cgroup_target.c_str()); - } else { - fprintf(stderr, "created cgroup %s failed, errno %d\n", - cgroup_target.c_str(), errno); - exit(EXIT_FAILURE); - } - // fprintf(stderr, "cgroup %s not exist\n",cgroup_target.c_str()); - // exit(EXIT_FAILURE); - } - - // put pid to target cgroup - ofstream procs(cgroup_target_procs, ofstream::app); - if (!procs.is_open()) { - fprintf(stderr, "open file %s failed\n", cgroup_target_procs.c_str()); - exit(EXIT_FAILURE); - } - procs << pid.c_str() << endl; - procs.close(); - - // maybe there some write error, for example process pid may not exist - if (!procs) { - fprintf(stderr, "write %s to %s failed, maybe process %s not exist\n", - pid.c_str(), cgroup_target_procs.c_str(), pid.c_str()); - exit(EXIT_FAILURE); - } - return EXIT_SUCCESS; + CGPROXY::CGROUP::attach(pid,cgroup_target); } diff --git a/common.h b/common.h index 7089b33..22a66cf 100644 --- a/common.h +++ b/common.h @@ -2,33 +2,77 @@ #define COMMON_H #define SOCKET_PATH "/tmp/unix_socket" -#define LISTEN_BACKLOG 5 +#define LISTEN_BACKLOG 64 #define DEFAULT_CONFIG_FILE "/etc/cgproxy.conf" +#define CGROUP_PROXY_PRESVERED "/proxy.slice" +#define CGROUP_NOPROXY_PRESVERED "/noproxy.slice" + #define MSG_TYPE_JSON 1 #define MSG_TYPE_CONFIG_PATH 2 #define MSG_TYPE_PROXY_PID 3 #define MSG_TYPE_NOPROXY_PID 4 -#define UNKNOWN_ERROR -99 -#define CONN_ERROR -1 -#define MSG_ERROR 1 -#define PARSE_ERROR 2 -#define PARAM_ERROR 3 -#define APPLY_ERROR 4 +#define UNKNOWN_ERROR 99 +#define ERROR -1 +#define SUCCESS 0 +#define CONN_ERROR 1 +#define MSG_ERROR 2 +#define PARSE_ERROR 3 +#define PARAM_ERROR 4 +#define APPLY_ERROR 5 +#define CGROUP_ERROR 6 +#define FILE_ERROR 7 + #include #include -#include +#include +#include using namespace std; -template string to_str(T... args) { + +#define error(...) {fprintf(stderr, __VA_ARGS__);fprintf(stderr, "\n");} +#define debug(...) {fprintf(stdout, __VA_ARGS__);fprintf(stdout, "\n");} +#define return_error return -1; +#define return_success return 0; + + +template +string to_str(T... args) { stringstream ss; ss.clear(); (ss << ... << args) << endl; return ss.str(); } -#define error(...) {fprintf(stderr, __VA_ARGS__);fprintf(stderr, "\n");} -#define debug(...) {fprintf(stdout, __VA_ARGS__);fprintf(stdout, "\n");} +template +string join2str(const T t){ + string s; + string delm=" ", prefix="(", tail=")", wrap="\""; + for (const auto &e : t) + e!=*(t.end()-1)?s+=wrap+e+wrap+delm:s+=wrap+e+wrap; + return prefix+s+tail; +} + +bool validCgroup(const string cgroup){ + return regex_match(cgroup, regex("^/[a-zA-Z0-9\\-_./@]*$")); +} + +bool validCgroup(const vector cgroup){ + for (auto &e:cgroup){ + if (!regex_match(e, regex("^/[a-zA-Z0-9\\-_./@]*$"))){ + return false; + } + } + return true; +} + +bool validPid(const string pid){ + return regex_match(pid, regex("^[0-9]+$")); +} + +bool validPort(const string port){ + return regex_match(port, regex("^[0-9]+$")); +} #endif \ No newline at end of file diff --git a/config.h b/config.h new file mode 100644 index 0000000..9f00a5d --- /dev/null +++ b/config.h @@ -0,0 +1,142 @@ +#ifndef CONFIG_H +#define CONFIG_H +#include "common.h" +#include "socket_server.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +using namespace std; +using json = nlohmann::json; + +namespace CGPROXY::CONFIG{ + +struct Config { + public: + const string cgroup_proxy_preserved=CGROUP_PROXY_PRESVERED; + const string cgroup_noproxy_preserved=CGROUP_NOPROXY_PRESVERED; + private: + vector cgroup_proxy; + vector cgroup_noproxy; + bool enable_gateway = false; + int port = 12345; + bool enable_dns = true; + bool enable_tcp = true; + bool enable_udp = true; + bool enable_ipv4 = true; + bool enable_ipv6 = true; + +public: + void toEnv() { + mergeReserved(); + setenv("cgroup_proxy", join2str(cgroup_proxy).c_str(), 1); + setenv("cgroup_noproxy", join2str(cgroup_noproxy).c_str(), 1); + setenv("enable_gateway", to_str(enable_gateway).c_str(), 1); + setenv("port", to_str(port).c_str(), 1); + setenv("enable_dns", to_str(enable_dns).c_str(), 1); + setenv("enable_tcp", to_str(enable_tcp).c_str(), 1); + setenv("enable_udp", to_str(enable_udp).c_str(), 1); + setenv("enable_ipv4", to_str(enable_ipv4).c_str(), 1); + setenv("enable_ipv6", to_str(enable_ipv6).c_str(), 1); + } + + int saveToFile(const string f){ + ofstream o(f); + if (!o.is_open()) return FILE_ERROR; + json j=toJson(); + o << setw(4) << j << endl; + o.close(); + return 0; + } + + json toJson(){ + json j; + #define add2json(v) j[#v]=v; + add2json(cgroup_proxy); + add2json(cgroup_noproxy); + add2json(enable_gateway); + add2json(port); + add2json(enable_dns); + add2json(enable_tcp); + add2json(enable_udp); + add2json(enable_ipv4); + add2json(enable_ipv6); + #undef add2json + return j; + } + + int loadFromFile(const string f) { + debug("loading config: %s", f.c_str()); + ifstream ifs(f); + if (ifs.is_open()){ + json j; + try { ifs >> j; }catch (exception& e){error("parse error: %s", f.c_str());ifs.close();return PARSE_ERROR;} + ifs.close(); + return loadFromJson(j); + }else{ + error("open failed: %s",f.c_str()); + return FILE_ERROR; + } + } + + int loadFromJson(const json &j) { + if (!validateJson(j)) {error("json validate fail"); return PARAM_ERROR;} + #define tryassign(v) try{j.at(#v).get_to(v);}catch(exception &e){} + tryassign(cgroup_proxy); + tryassign(cgroup_noproxy); + tryassign(enable_gateway); + tryassign(port); + tryassign(enable_dns); + tryassign(enable_tcp); + tryassign(enable_udp); + tryassign(enable_ipv4); + tryassign(enable_ipv6); + #undef assign + return 0; + } + + void mergeReserved(){ + #define merge(v) { \ + v.erase(std::remove(v.begin(), v.end(), v ## _preserved), v.end()); \ + v.insert(v.begin(), v ## _preserved); \ + } + merge(cgroup_proxy); + merge(cgroup_noproxy); + #undef merge + + } + + bool validateJson(const json &j){ + bool status=true; + const set boolset={"enable_gateway","enable_dns","enable_tcp","enable_udp","enable_ipv4","enable_ipv6"}; + for (auto& [key, value] : j.items()) { + if (key=="cgroup_proxy"||key=="cgroup_noproxy"){ + if (value.is_string()&&!validCgroup((string)value)) status=false; + if (value.is_array()&&!validCgroup((vector)value)) status=false; + if (!value.is_string()&&!value.is_array()) status=false; + }else if (key=="port"){ + if (validPort(value)) status=false; + }else if (boolset.find(key)!=boolset.end()){ + if (value.is_boolean()) status=false; + }else{ + error("unknown key: %s", key.c_str()); + return false; + } + if (!status) { + error("invalid value for key: %s", key.c_str()); + return false; + } + } + return true; + } +}; + +} +#endif \ No newline at end of file diff --git a/config.json b/config.json new file mode 100644 index 0000000..383d3c1 --- /dev/null +++ b/config.json @@ -0,0 +1,11 @@ +{ + "cgroup_noproxy": [], + "cgroup_proxy": [], + "enable_dns": true, + "enable_gateway": false, + "enable_ipv4": true, + "enable_ipv6": true, + "enable_tcp": true, + "enable_udp": true, + "port": 12345 +} diff --git a/main.cpp b/main.cpp index a02846a..838874d 100644 --- a/main.cpp +++ b/main.cpp @@ -8,103 +8,19 @@ #include #include #include -#include +#include +#include "config.h" +#include "cgroup_attach.h" using namespace std; using json = nlohmann::json; +using namespace CGPROXY::SOCKET; +using namespace CGPROXY::CONFIG; +using namespace CGPROXY::CGROUP; -struct Config { - string cgroup_proxy = "/proxy.slice"; - string cgroup_noproxy = "/noproxy.slice"; - bool enable_gateway = false; - int port = 12345; - bool enable_dns = true; - bool enable_tcp = true; - bool enable_udp = true; - bool enable_ipv4 = true; - bool enable_ipv6 = true; - - void toEnv() { - #define env(v) setenv(#v, to_str(v).c_str(), 1) - env(cgroup_proxy); - env(cgroup_noproxy); - env(enable_gateway); - env(port); - env(enable_dns); - env(enable_tcp); - env(enable_udp); - env(enable_ipv4); - env(enable_ipv6); - #undef env - } - - int safeLoadFromFile(const string path){ - Config tmp=*this; - int flag=tmp.loadFromFile(path); - if (flag!=0) return flag; - if (tmp.isValid()){ - loadFromFile(path); - return 0; - }else{ - return PARAM_ERROR; - } - } - - int safeLoadFromJson(const json& j){ - Config tmp=*this; - int flag=tmp.loadFromJson(j); - if (flag!=0) return flag; - if (tmp.isValid()){ - loadFromJson(j); - return 0; - }else{ - return PARAM_ERROR; - } - } - - private: - int loadFromFile(const string f) { - debug("loading config: %s", f.c_str()); - libconfig::Config config_f; - try { config_f.readFile(f.c_str()); } catch (exception &e) { return PARSE_ERROR; } - #define assign(v, t) if (config_f.exists(#v)) {v = (t)config_f.lookup(#v);} - assign(cgroup_proxy, string); - assign(cgroup_noproxy, string); - assign(enable_gateway, bool); - assign(port, int); - assign(enable_dns, bool); - assign(enable_tcp, bool); - assign(enable_udp, bool); - assign(enable_ipv4, bool); - assign(enable_ipv6, bool); - #undef assign - return 0; - } - - int loadFromJson(const json &j) { - #define get_to(v) try {j.at(#v).get_to(v); } catch (exception& e) {} - get_to(cgroup_proxy); - get_to(cgroup_noproxy); - get_to(enable_gateway); - get_to(port); - get_to(enable_dns); - get_to(enable_tcp); - get_to(enable_udp); - get_to(enable_ipv4); - get_to(enable_ipv6); - #undef get_to - return 0; - } - - bool isValid(){ - // TODO - return true; - } -}; - -SocketControl sc; +SocketServer sc; thread_arg arg_t; -Config config_tproxy; +Config config; pthread_t socket_thread_id = -1; int applyConfig(Config *c) { @@ -118,21 +34,42 @@ int handle_msg(char *msg) { debug("received msg: %s", msg); json j; try{ j = json::parse(msg); }catch(exception& e){debug("msg paser error");return MSG_ERROR;} - int type = -1, status; + + int type, status; + string pid, cgroup_target; try { type = j.at("type").get(); - if (type == MSG_TYPE_JSON) { // json data - status=config_tproxy.safeLoadFromJson(j.at("data")); - } else if (type == MSG_TYPE_CONFIG_PATH) { // config file - status=config_tproxy.safeLoadFromFile(j.at("data").get()); - } + switch (type) + { + case MSG_TYPE_JSON: + status=config.loadFromJson(j.at("data")); + if (status==SUCCESS) status=applyConfig(&config); + return status; + break; + case MSG_TYPE_CONFIG_PATH: + status=config.loadFromFile(j.at("data").get()); + if (status==SUCCESS) status=applyConfig(&config); + return status; + break; + case MSG_TYPE_PROXY_PID: + pid=j.at("data").get(); + status=attach(pid, config.cgroup_proxy_preserved); + return status; + break; + case MSG_TYPE_NOPROXY_PID: + pid=j.at("data").get(); + status=attach(pid, config.cgroup_noproxy_preserved); + return status; + break; + default: + return MSG_ERROR; + break; + }; } catch (out_of_range &e) { return MSG_ERROR; + } catch (exception &e){ + return ERROR; } - if (status==0){ - return applyConfig(&config_tproxy); - } - return status; } pthread_t startSocketListeningThread() { @@ -140,7 +77,7 @@ pthread_t startSocketListeningThread() { arg_t.handle_msg = &handle_msg; pthread_t thread_id; int status = - pthread_create(&thread_id, NULL, &SocketControl::startThread, &arg_t); + pthread_create(&thread_id, NULL, &SocketServer::startThread, &arg_t); if (status != 0) error("socket thread create failed"); return thread_id; @@ -149,13 +86,11 @@ pthread_t startSocketListeningThread() { int main() { bool enable_socket = true; string config_path = DEFAULT_CONFIG_FILE; - config_tproxy.safeLoadFromFile(config_path); - applyConfig(&config_tproxy); + config.loadFromFile(config_path); + applyConfig(&config); if (enable_socket) { socket_thread_id = startSocketListeningThread(); pthread_join(socket_thread_id, NULL); } return 0; -} - -// TODO handle attch pid \ No newline at end of file +} \ No newline at end of file diff --git a/socket_client.cpp b/socket_client.h similarity index 74% rename from socket_client.cpp rename to socket_client.h index 7301ba9..143ff55 100644 --- a/socket_client.cpp +++ b/socket_client.h @@ -1,25 +1,29 @@ -#include "common.h" +#ifndef SOCKET_CLIENT_H +#define SOCKET_CLIENT_H + #include -#include #include #include -#include +#include #include #include #include #include +#include "common.h" using namespace std; -using json = nlohmann::json; + +namespace CGPROXY::SOCKET{ #define return_if_error(flag, msg) \ if (flag == -1) { \ perror(msg); \ status = CONN_ERROR; \ + close(sfd); \ return; \ } -void send(char *msg, int &status) { +void send(const char *msg, int &status) { debug("send msg: %s", msg); status = UNKNOWN_ERROR; @@ -47,8 +51,7 @@ void send(char *msg, int &status) { close(sfd); } -void send(const json &j, int &status) { - string msg = j.dump(); +void send(const string msg, int &status) { int msg_len = msg.length(); char buff[msg_len]; msg.copy(buff, msg_len, 0); @@ -57,24 +60,5 @@ void send(const json &j, int &status) { debug("return status: %d", status); } -int test_json() { - json j; - j["type"] = MSG_TYPE_JSON; - j["data"]["cgroup_proxy"] = "/"; - j["data"]["enable_dns"] = false; - int status; - send(j, status); } - -void test_file() { - json j; - j["type"] = MSG_TYPE_CONFIG_PATH; - j["data"] = "/etc/cgproxy.conf"; - int status; - send(j, status); -} - -int main() { - test_file(); - test_json(); -} \ No newline at end of file +#endif \ No newline at end of file diff --git a/socket_server.h b/socket_server.h index 2980251..dc08593 100644 --- a/socket_server.h +++ b/socket_server.h @@ -1,25 +1,21 @@ #ifndef SOCKET_SERVER_H #define SOCKET_SERVER_H -#include "common.h" #include #include #include #include #include -#include +#include #include #include #include #include #include +#include "common.h" using namespace std; -#define SOCKET_PATH "/tmp/unix_socket" -#define LISTEN_BACKLOG 5 - -class SocketControl; -struct thread_arg; +namespace CGPROXY::SOCKET{ #define continue_if_error(flag, msg) \ if (flag == -1) { \ @@ -27,12 +23,13 @@ struct thread_arg; continue; \ } +class SocketServer; struct thread_arg { - SocketControl *sc; + SocketServer *sc; function handle_msg; }; -class SocketControl { +class SocketServer { public: int sfd = -1, cfd = -1, flag = -1; struct sockaddr_un unix_socket; @@ -41,7 +38,8 @@ public: debug("starting socket listening"); sfd = socket(AF_UNIX, SOCK_STREAM, 0); - unlink(SOCKET_PATH); + flag=unlink(SOCKET_PATH); + if (flag==-1) {error("%s exist, and can't unlink",SOCKET_PATH); exit(EXIT_FAILURE);} memset(&unix_socket, '\0', sizeof(struct sockaddr_un)); unix_socket.sun_family = AF_UNIX; strncpy(unix_socket.sun_path, SOCKET_PATH, @@ -75,7 +73,7 @@ public: } } - ~SocketControl() { + ~SocketServer() { close(sfd); close(cfd); unlink(SOCKET_PATH); @@ -84,7 +82,10 @@ public: static void *startThread(void *arg) { thread_arg *p = (thread_arg *)arg; p->sc->socketListening(p->handle_msg); + return (void *)0; } }; +} + #endif \ No newline at end of file