add basic unix socket control

This commit is contained in:
fancy
2020-05-13 23:43:07 +08:00
parent f8e0abbb55
commit 4e37bccc1a
5 changed files with 385 additions and 0 deletions

View File

@@ -1,8 +1,26 @@
cmake_minimum_required(VERSION 3.10)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_BUILD_TYPE DEBUG)
add_compile_options(-Wno-return-type)
project(cgproxy VERSION 3.7)
find_package(Threads REQUIRED)
find_package(libconfig++ REQUIRED)
find_package(nlohmann_json REQUIRED)
include_directories(${PROJECT_SOURCE_DIR})
add_executable(cgattach cgattach.cpp)
add_executable(main main.cpp)
target_link_libraries(main Threads::Threads ${LIBCONFIG++_LIBRARIES} nlohmann_json::nlohmann_json)
add_executable(client socket_client.cpp)
target_link_libraries(client nlohmann_json::nlohmann_json)
install(TARGETS cgattach DESTINATION /usr/bin
PERMISSIONS OWNER_READ OWNER_WRITE OWNER_EXECUTE GROUP_READ GROUP_EXECUTE WORLD_READ WORLD_EXECUTE SETUID)
install(FILES cgproxy.sh DESTINATION /usr/bin
@@ -26,6 +44,8 @@ install(FILES readme.md
DESTINATION /share/doc/cgproxy/)
## package for deb and rpm
set(CPACK_GENERATOR "DEB;RPM")
set(CPACK_PACKAGE_NAME "cgproxy")

34
common.h Normal file
View File

@@ -0,0 +1,34 @@
#ifndef COMMON_H
#define COMMON_H
#define SOCKET_PATH "/tmp/unix_socket"
#define LISTEN_BACKLOG 5
#define DEFAULT_CONFIG_FILE "/etc/cgproxy.conf"
#define MSG_TYPE_JSON 1
#define MSG_TYPE_CONFIG_PATH 2
#define MSG_TYPE_PROXY_PID 3
#define MSG_TYPE_NOPROXY_PID 4
#define UNKNOWN_ERROR -99
#define CONN_ERROR -1
#define MSG_ERROR 1
#define PARSE_ERROR 2
#define PARAM_ERROR 3
#define APPLY_ERROR 4
#include <iostream>
#include <sstream>
#include <string.h>
using namespace std;
template <typename... T> string to_str(T... args) {
stringstream ss;
ss.clear();
(ss << ... << args) << endl;
return ss.str();
}
#define error(...) {fprintf(stderr, __VA_ARGS__);fprintf(stderr, "\n");}
#define debug(...) {fprintf(stdout, __VA_ARGS__);fprintf(stdout, "\n");}
#endif

161
main.cpp Normal file
View File

@@ -0,0 +1,161 @@
#include "common.h"
#include "socket_server.h"
#include <fstream>
#include <iostream>
#include <libconfig.h++>
#include <nlohmann/json.hpp>
#include <pthread.h>
#include <sstream>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
using namespace std;
using json = nlohmann::json;
struct Config {
string cgroup_proxy = "/proxy.slice";
string cgroup_noproxy = "/noproxy.slice";
bool enable_gateway = false;
int port = 12345;
bool enable_dns = true;
bool enable_tcp = true;
bool enable_udp = true;
bool enable_ipv4 = true;
bool enable_ipv6 = true;
void toEnv() {
#define env(v) setenv(#v, to_str(v).c_str(), 1)
env(cgroup_proxy);
env(cgroup_noproxy);
env(enable_gateway);
env(port);
env(enable_dns);
env(enable_tcp);
env(enable_udp);
env(enable_ipv4);
env(enable_ipv6);
#undef env
}
int safeLoadFromFile(const string path){
Config tmp=*this;
int flag=tmp.loadFromFile(path);
if (flag!=0) return flag;
if (tmp.isValid()){
loadFromFile(path);
return 0;
}else{
return PARAM_ERROR;
}
}
int safeLoadFromJson(const json& j){
Config tmp=*this;
int flag=tmp.loadFromJson(j);
if (flag!=0) return flag;
if (tmp.isValid()){
loadFromJson(j);
return 0;
}else{
return PARAM_ERROR;
}
}
private:
int loadFromFile(const string f) {
debug("loading config: %s", f.c_str());
libconfig::Config config_f;
try { config_f.readFile(f.c_str()); } catch (exception &e) { return PARSE_ERROR; }
#define assign(v, t) if (config_f.exists(#v)) {v = (t)config_f.lookup(#v);}
assign(cgroup_proxy, string);
assign(cgroup_noproxy, string);
assign(enable_gateway, bool);
assign(port, int);
assign(enable_dns, bool);
assign(enable_tcp, bool);
assign(enable_udp, bool);
assign(enable_ipv4, bool);
assign(enable_ipv6, bool);
#undef assign
return 0;
}
int loadFromJson(const json &j) {
#define get_to(v) try {j.at(#v).get_to(v); } catch (exception& e) {}
get_to(cgroup_proxy);
get_to(cgroup_noproxy);
get_to(enable_gateway);
get_to(port);
get_to(enable_dns);
get_to(enable_tcp);
get_to(enable_udp);
get_to(enable_ipv4);
get_to(enable_ipv6);
#undef get_to
return 0;
}
bool isValid(){
// TODO
return true;
}
};
SocketControl sc;
thread_arg arg_t;
Config config_tproxy;
pthread_t socket_thread_id = -1;
int applyConfig(Config *c) {
system("sh /usr/share/cgproxy/scripts/cgroup-tproxy.sh stop");
c->toEnv();
system("sh /usr/share/cgproxy/scripts/cgroup-tproxy.sh");
return 0;
}
int handle_msg(char *msg) {
debug("received msg: %s", msg);
json j;
try{ j = json::parse(msg); }catch(exception& e){debug("msg paser error");return MSG_ERROR;}
int type = -1, status;
try {
type = j.at("type").get<int>();
if (type == MSG_TYPE_JSON) { // json data
status=config_tproxy.safeLoadFromJson(j.at("data"));
} else if (type == MSG_TYPE_CONFIG_PATH) { // config file
status=config_tproxy.safeLoadFromFile(j.at("data").get<string>());
}
} catch (out_of_range &e) {
return MSG_ERROR;
}
if (status==0){
return applyConfig(&config_tproxy);
}
return status;
}
pthread_t startSocketListeningThread() {
arg_t.sc = &sc;
arg_t.handle_msg = &handle_msg;
pthread_t thread_id;
int status =
pthread_create(&thread_id, NULL, &SocketControl::startThread, &arg_t);
if (status != 0)
error("socket thread create failed");
return thread_id;
}
int main() {
bool enable_socket = true;
string config_path = DEFAULT_CONFIG_FILE;
config_tproxy.safeLoadFromFile(config_path);
applyConfig(&config_tproxy);
if (enable_socket) {
socket_thread_id = startSocketListeningThread();
pthread_join(socket_thread_id, NULL);
}
return 0;
}
// TODO handle attch pid

80
socket_client.cpp Normal file
View File

@@ -0,0 +1,80 @@
#include "common.h"
#include <iostream>
#include <nlohmann/json.hpp>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
#include <unistd.h>
using namespace std;
using json = nlohmann::json;
#define return_if_error(flag, msg) \
if (flag == -1) { \
perror(msg); \
status = CONN_ERROR; \
return; \
}
void send(char *msg, int &status) {
debug("send msg: %s", msg);
status = UNKNOWN_ERROR;
int flag;
int sfd = socket(AF_UNIX, SOCK_STREAM, 0);
struct sockaddr_un unix_socket;
memset(&unix_socket, '\0', sizeof(struct sockaddr_un));
unix_socket.sun_family = AF_UNIX;
strncpy(unix_socket.sun_path, SOCKET_PATH, sizeof(unix_socket.sun_path) - 1);
flag =
connect(sfd, (struct sockaddr *)&unix_socket, sizeof(struct sockaddr_un));
return_if_error(flag, "connect");
int msg_len = strlen(msg);
flag = write(sfd, &msg_len, sizeof(int));
return_if_error(flag, "write length");
flag = write(sfd, msg, msg_len * sizeof(char));
return_if_error(flag, "write msg");
flag = read(sfd, &status, sizeof(int));
return_if_error(flag, "read return value");
close(sfd);
}
void send(const json &j, int &status) {
string msg = j.dump();
int msg_len = msg.length();
char buff[msg_len];
msg.copy(buff, msg_len, 0);
buff[msg_len] = '\0';
send(buff, status);
debug("return status: %d", status);
}
int test_json() {
json j;
j["type"] = MSG_TYPE_JSON;
j["data"]["cgroup_proxy"] = "/";
j["data"]["enable_dns"] = false;
int status;
send(j, status);
}
void test_file() {
json j;
j["type"] = MSG_TYPE_CONFIG_PATH;
j["data"] = "/etc/cgproxy.conf";
int status;
send(j, status);
}
int main() {
test_file();
test_json();
}

90
socket_server.h Normal file
View File

@@ -0,0 +1,90 @@
#ifndef SOCKET_SERVER_H
#define SOCKET_SERVER_H
#include "common.h"
#include <functional>
#include <iostream>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
#include <unistd.h>
#include <sys/stat.h>
using namespace std;
#define SOCKET_PATH "/tmp/unix_socket"
#define LISTEN_BACKLOG 5
class SocketControl;
struct thread_arg;
#define continue_if_error(flag, msg) \
if (flag == -1) { \
perror(msg); \
continue; \
}
struct thread_arg {
SocketControl *sc;
function<int(char *)> handle_msg;
};
class SocketControl {
public:
int sfd = -1, cfd = -1, flag = -1;
struct sockaddr_un unix_socket;
void socketListening(function<int(char *)> callback) {
debug("starting socket listening");
sfd = socket(AF_UNIX, SOCK_STREAM, 0);
unlink(SOCKET_PATH);
memset(&unix_socket, '\0', sizeof(struct sockaddr_un));
unix_socket.sun_family = AF_UNIX;
strncpy(unix_socket.sun_path, SOCKET_PATH,
sizeof(unix_socket.sun_path) - 1);
bind(sfd, (struct sockaddr *)&unix_socket, sizeof(struct sockaddr_un));
listen(sfd, LISTEN_BACKLOG);
chmod(SOCKET_PATH,S_IRWXU|S_IRWXG|S_IRWXO);
while (true) {
close(cfd);
cfd = accept(sfd, NULL, NULL);
continue_if_error(cfd, "accept");
debug("accept connection: %d", cfd);
// read length
int msg_len;
flag = read(cfd, &msg_len, sizeof(int));
continue_if_error(flag, "read length");
// read msg
char msg[msg_len];
flag = read(cfd, msg, msg_len * sizeof(char));
continue_if_error(flag, "read msg");
msg[msg_len]='\0';
// handle msg
int status = callback(msg);
// send back flag
flag = write(cfd, &status, sizeof(int));
continue_if_error(flag, "write back");
}
}
~SocketControl() {
close(sfd);
close(cfd);
unlink(SOCKET_PATH);
}
static void *startThread(void *arg) {
thread_arg *p = (thread_arg *)arg;
p->sc->socketListening(p->handle_msg);
}
};
#endif