Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

redbean: add tls socket lua binding #1279

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
42 changes: 42 additions & 0 deletions tool/net/definitions.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4982,6 +4982,48 @@ unix = {
X_OK = nil
}

---@class TlsContext
---@field connect fun(self: TlsContext, server_name: string, server_port: string): boolean, string?
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In terms of API design, have you considered this?

tls = require "tls"
conn = tls.TlsClient(ResolveIp("google.com"), 80) -- returns object of TlsClient class
conn.write("GET / HTTP/1.0\r\n\r\n")
print(conn.read())
conn.close()

That would be more similar and composable with other redbean APIs. For example, we like to pass IPs as uint32. You wouldn't need to maintain a state machine with this design. You could have functions or variables associated with the tls module for doing context-wide configuration, like whether or not SSL client verification should be enabled. You could also have that default to the redbean settings.

Another thing you might do that's even better is:

tls = require "tls"
unix = require "unix"

fd = assert(unix.socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_IP))
assert(unix.connect(fd, ResolveIp("google.com"), 80))
conn = assert(tls.TlsClient(fd)) -- returns object of TlsSocket class
assert(conn.write("GET / HTTP/1.0\r\n\r\n"))
response = assert(conn.read())
print(response)
assert(unix.close(fd))

I noticed you're drawing a lot of influence from mbedTLS's net_sockets.c API. I don't like that API. I don't think it's very good. If you use that abstraction, then you lose the ability to compose with redbean APIs like ResolveIp(), unix.poll(), etc. For examples of how I've made mbedTLS work with raw file descriptors, see tool/curl/curl.c and tool/build/lib/eztls.c.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This API indeed seems more natural. I will make the modifications to get as close as possible to this example. Thank you

Copy link
Author

@chamot1111 chamot1111 Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I use now:

---@class tls
local tls = {}

--- Creates a new TLS client.
---@param fd integer File descriptor of the socket
---@param verify? boolean Whether to verify the server's certificate (default: true)
---@param timeout? integer Read timeout in milliseconds (default: 0, no timeout)
---@return TlsContext|nil context
---@return string? error
function tls.TlsClient(fd, verify, timeout) end

--- Writes data to the TLS connection.
---@param context TlsContext
---@param data string
---@return integer bytes_written
---@return string? error
function tls:write(data) end

--- Reads data from the TLS connection.
---@param context TlsContext
---@param bufsiz? integer Maximum number of bytes to read (default: BUFSIZ)
---@return string? data
---@return string? error
function tls:read(bufsiz) end

---@field write fun(self: TlsContext, data: string): integer, string?
---@field read fun(self: TlsContext, bufsiz?: integer): string?, string?
---@field close fun(self: TlsContext)

---@class tls
local tls = {}

--- Creates a new TLS socket.
---@param verify? boolean Whether to verify the server's certificate (default: true)
---@param timeout? integer Read timeout in milliseconds (default: 0, no timeout)
---@return TlsContext|nil context
---@return string? error
function tls.socket(verify, timeout) end

--- Connects to a server using TLS.
---@param context TlsContext
---@param server_name string
---@param server_port string
---@return boolean success
---@return string? error
function tls:connect(server_name, server_port) end

--- Writes data to the TLS connection.
---@param context TlsContext
---@param data string
---@return integer bytes_written
---@return string? error
function tls:write(data) end

--- Reads data from the TLS connection.
---@param context TlsContext
---@param bufsiz? integer Maximum number of bytes to read (default: BUFSIZ)
---@return string? data
---@return string? error
function tls:read(bufsiz) end

--- Closes the TLS connection.
---@param context TlsContext
function tls:close() end

--- Opens file.
---
--- Returns a file descriptor integer that needs to be closed, e.g.
Expand Down
293 changes: 293 additions & 0 deletions tool/net/ltls.inc
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
#include "libc/intrin/kprintf.h"
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can probably remove this line.

I'm also reasonably certain this doesn't need to be a .inc file. LuaFetch() had to be one because it's dug in like a tick with all the global variables defined in redbean.c. But I see that here you're going for a more conventional library API. For example, rather than checking the sslclientverify you're just accepting that as a parameter. That is probably the right thing to do, honestly.

So could you, if possible, make this ltls.c and update tool/net/BUILD.mk appropriately? Be sure to put the standard copyright header at the top.

static const char *const tls_meta = ":mbedtls";

typedef enum {
TLS_STATE_INIT,
TLS_STATE_CONNECTED,
TLS_STATE_CLOSED
} TlsConnectionState;

typedef struct {
mbedtls_entropy_context entropy;
mbedtls_ctr_drbg_context ctr_drbg;
mbedtls_ssl_context ssl;
mbedtls_ssl_config conf;
mbedtls_net_context server_fd;
int ref; // Reference to self in the Lua registry
TlsConnectionState connection_state;
char *read_buffer;
size_t read_buffer_size;
} TlsContext;

static TlsContext **checktls(lua_State *L) {
TlsContext **tls = (TlsContext **)luaL_checkudata(L, 1, tls_meta);
if (tls == NULL || *tls == NULL)
luaL_typeerror(L, 1, tls_meta);
return tls;
}

static int tls_gc(lua_State *L) {
TlsContext **tlsp = checktls(L);
TlsContext *tls = *tlsp;

if (tls) {
if (tls->connection_state != TLS_STATE_CLOSED) {
mbedtls_net_free(&tls->server_fd);
mbedtls_ssl_free(&tls->ssl);
mbedtls_ssl_config_free(&tls->conf);
mbedtls_ctr_drbg_free(&tls->ctr_drbg);
mbedtls_entropy_free(&tls->entropy);
}
mbedtls_ssl_free(&tls->ssl);
luaL_unref(L, LUA_REGISTRYINDEX, tls->ref);
free(tls->read_buffer);
free(tls);
*tlsp = NULL;
}
return 0;
}

static int tls_socket(lua_State *L) {
if (!sslinitialized) {
TlsInit();
}

TlsContext **tlsp = (TlsContext **)lua_newuserdata(L, sizeof(TlsContext *));
*tlsp = NULL;

luaL_getmetatable(L, tls_meta);
lua_setmetatable(L, -2);

TlsContext *tls = (TlsContext *)malloc(sizeof(TlsContext));
if (tls == NULL) {
lua_pushnil(L);
lua_pushstring(L, "Failed to allocate memory for TLS context");
return 2;
}
*tlsp = tls;

tls->connection_state = TLS_STATE_INIT;
tls->read_buffer = NULL;
tls->read_buffer_size = 0;

mbedtls_net_init(&tls->server_fd);
mbedtls_ssl_init(&tls->ssl);
mbedtls_ssl_config_init(&tls->conf);
mbedtls_ctr_drbg_init(&tls->ctr_drbg);
mbedtls_entropy_init(&tls->entropy);
int sslVerify = lua_isnone(L, 1) ? 1 : lua_toboolean(L, 1);
if (sslVerify) {
mbedtls_ssl_conf_ca_chain(&tls->conf, GetSslRoots(), 0);
mbedtls_ssl_conf_authmode(&tls->conf, MBEDTLS_SSL_VERIFY_REQUIRED);
} else {
mbedtls_ssl_conf_authmode(&tls->conf, MBEDTLS_SSL_VERIFY_NONE);
}

int timeout = lua_isnone(L, 2) ? 0 : luaL_checkinteger(L, 2);
mbedtls_ssl_conf_read_timeout(&tls->conf, timeout);

const char *pers = "tls_socket";
int ret;
if ((ret = mbedtls_ctr_drbg_seed(&tls->ctr_drbg, mbedtls_entropy_func,
&tls->entropy, (const unsigned char *)pers,
strlen(pers))) != 0) {
free(tls);
*tlsp = NULL;
lua_pushnil(L);
lua_pushfstring(L, "mbedtls_ctr_drbg_seed returned %d", ret);
return 2;
}

if ((ret = mbedtls_ssl_setup(&tls->ssl, &tls->conf)) != 0) {
free(tls);
*tlsp = NULL;
lua_pushnil(L);
lua_pushfstring(L, "mbedtls_ssl_setup returned %d", ret);
return 2;
}

tls->ref = luaL_ref(L, LUA_REGISTRYINDEX);
lua_rawgeti(L, LUA_REGISTRYINDEX, tls->ref);

return 1;
}

static void my_debug(void *ctx, int level, const char *file, int line,
const char *str) {
((void)level);
fprintf((FILE *)ctx, "%s:%04d: %s", file, line, str);
fflush((FILE *)ctx);
}

static int tls_connect(lua_State *L) {
TlsContext **tlsp = checktls(L);
TlsContext *tls = *tlsp;
const char *server_name = luaL_checkstring(L, 2);
const char *server_port = luaL_checkstring(L, 3);

int ret;
if ((ret = mbedtls_net_connect(&tls->server_fd, server_name, server_port,
MBEDTLS_NET_PROTO_TCP)) != 0) {
lua_pushnil(L);
lua_pushfstring(L, "connect failed: %d", ret);
return 2;
}

if ((ret = mbedtls_ssl_config_defaults(&tls->conf, MBEDTLS_SSL_IS_CLIENT,
MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT)) != 0) {
lua_pushnil(L);
lua_pushfstring(L, "mbedtls_ssl_config_defaults failed: %d", ret);
return 2;
}

// mbedtls_ssl_conf_authmode(&tls->conf, MBEDTLS_SSL_VERIFY_NONE); // only for
// test, conf mbedtls_x509_crt_init instead

mbedtls_ssl_conf_rng(&tls->conf, mbedtls_ctr_drbg_random, &tls->ctr_drbg);
mbedtls_ssl_conf_dbg(&tls->conf, my_debug, stdout);

if ((ret = mbedtls_ssl_set_hostname(&tls->ssl, server_name)) != 0) {
lua_pushnil(L);
lua_pushfstring(L, "mbedtls_ssl_set_hostname failed: %d", ret);
return 2;
}

mbedtls_ssl_set_bio(&tls->ssl, &tls->server_fd, mbedtls_net_send, NULL,
mbedtls_net_recv_timeout);

if ((ret = mbedtls_ssl_handshake(&tls->ssl)) != 0) {
lua_pushnil(L);
lua_pushfstring(L, "SSL handshake failed: %d", ret);
return 2;
}

tls->connection_state = TLS_STATE_CONNECTED;

lua_pushboolean(L, 1);
return 1;
}

static int tls_write(lua_State *L) {
TlsContext **tlsp = checktls(L);
TlsContext *tls = *tlsp;
size_t len;
const char *data = luaL_checklstring(L, 2, &len);
int ret = mbedtls_ssl_write(&tls->ssl, (const unsigned char *)data, len);

if (ret < 0) {
lua_pushnil(L);
lua_pushfstring(L, "SSL write failed: %d", ret);
return 2;
}

lua_pushinteger(L, ret);
return 1;
}

static int tls_read(lua_State *L) {
TlsContext **tlsp = checktls(L);
TlsContext *tls = *tlsp;
lua_Integer bufsiz = luaL_optinteger(L, 2, BUFSIZ);
bufsiz = MIN(bufsiz, 0x7ffff000);

if (tls->read_buffer == NULL || tls->read_buffer_size < bufsiz) {
char *new_buffer = realloc(tls->read_buffer, bufsiz);
if (new_buffer == NULL) {
lua_pushnil(L);
lua_pushstring(L, "Memory allocation failed");
return 2;
}
tls->read_buffer = new_buffer;
tls->read_buffer_size = bufsiz;
}

int ret =
mbedtls_ssl_read(&tls->ssl, (unsigned char *)tls->read_buffer, bufsiz);

if (ret > 0) {
lua_pushlstring(L, tls->read_buffer, ret);
return 1;
} else if (ret == 0) {
// End of file
lua_pushnil(L);
return 1;
} else {
lua_pushnil(L);
if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you really support non-blocking i/o?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No I don't, i remove it

lua_pushstring(L, "EAGAIN");
} else {
lua_pushfstring(L, "Read error: %d", ret);
}
return 2;
}
}

static int tls_close(lua_State *L) {
TlsContext **tlsp = checktls(L);
TlsContext *tls = *tlsp;

mbedtls_net_free(&tls->server_fd);
mbedtls_ssl_free(&tls->ssl);
mbedtls_ssl_config_free(&tls->conf);
mbedtls_ctr_drbg_free(&tls->ctr_drbg);
mbedtls_entropy_free(&tls->entropy);
free(tls->read_buffer);
tls->read_buffer = NULL;
tls->read_buffer_size = 0;
tls->connection_state = TLS_STATE_CLOSED;

return 0;
}

static int tls_tostring(lua_State *L) {
TlsContext **tlsp = checktls(L);
TlsContext *tls = *tlsp;
const char *state_str;

switch (tls->connection_state) {
case TLS_STATE_INIT:
state_str = "initialized";
break;
case TLS_STATE_CONNECTED:
state_str = "connected";
break;
case TLS_STATE_CLOSED:
state_str = "closed";
break;
default:
state_str = "unknown";
}

lua_pushfstring(L, "TLS connection (%p): %s", tls, state_str);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With redbean what we like to do is have a __repr method so when an object is printed on the redbean shell, it'll print the code you'd type to create that object. Then the __tostring method will usually be the same thing as __repr. I don't think it's worth showing state here, because the state you're showing is just the internal class state and might not reflect the current status in the OS of the socket.

Copy link
Author

@chamot1111 chamot1111 Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. I will use:

static int tls_tostring(lua_State *L) {
  TlsContext **tlsp = checktls(L);
  TlsContext *tls = *tlsp;

  lua_pushfstring(L, "tls.TlsClient({fd=%d})", tls->fd);
  return 1;
}

static const struct luaL_Reg tls_methods[] = {
    {"write", tls_write},
    {"read", tls_read},
    {"close", tls_close},
    {"__gc", tls_gc},
    {"__tostring", tls_tostring},
    {"__repr", tls_tostring},
    {NULL, NULL}
};

return 1;
}

static const struct luaL_Reg tls_methods[] = {{"connect", tls_connect},
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please put the elements on their own lines.

{"write", tls_write},
{"read", tls_read},
{"close", tls_close},
{"__gc", tls_gc},
{"__tostring", tls_tostring},
{NULL, NULL}};

static const struct luaL_Reg tlslib[] = {{"socket", tls_socket}, {NULL, NULL}};

static void create_meta(lua_State *L, const char *name,
const luaL_Reg *methods) {
luaL_newmetatable(L, name);
lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
luaL_setfuncs(L, methods, 0);
}

LUALIB_API int luaopen_tls(lua_State *L) {
create_meta(L, tls_meta, tls_methods);

luaL_newlib(L, tlslib);

lua_pushvalue(L, -1);
lua_setmetatable(L, -2);

return 1;
}
5 changes: 5 additions & 0 deletions tool/net/redbean.c
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@
#include "third_party/mbedtls/ssl_ticket.h"
#include "third_party/mbedtls/x509.h"
#include "third_party/mbedtls/x509_crt.h"
#include "third_party/mbedtls/entropy.h"
#include "third_party/musl/netdb.h"
#include "third_party/zlib/zlib.h"
#include "tool/build/lib/case.h"
Expand Down Expand Up @@ -3977,6 +3978,7 @@ static int LuaNilTlsError(lua_State *L, const char *s, int r) {
}

#include "tool/net/fetch.inc"
#include "tool/net/ltls.inc"

static int LuaGetDate(lua_State *L) {
lua_pushinteger(L, shared->nowish.tv_sec);
Expand Down Expand Up @@ -5401,6 +5403,9 @@ static const luaL_Reg kLuaFuncs[] = {
static const luaL_Reg kLuaLibs[] = {
{"argon2", luaopen_argon2}, //
{"lsqlite3", luaopen_lsqlite3}, //
#ifndef UNSECURE
{"tls", luaopen_tls}, //
#endif
{"maxmind", LuaMaxmind}, //
{"finger", LuaFinger}, //
{"path", LuaPath}, //
Expand Down
Loading