diff options
-rw-r--r-- | p11p-daemon/src/p11p_remote.erl | 84 | ||||
-rw-r--r-- | p11p-daemon/src/p11p_rpc.hrl | 3 | ||||
-rw-r--r-- | p11p-daemon/src/p11p_server.erl | 70 |
3 files changed, 96 insertions, 61 deletions
diff --git a/p11p-daemon/src/p11p_remote.erl b/p11p-daemon/src/p11p_remote.erl index d0f9184..b27b333 100644 --- a/p11p-daemon/src/p11p_remote.erl +++ b/p11p-daemon/src/p11p_remote.erl @@ -10,14 +10,14 @@ %% times out, inform the remote manager (our parent). %% TODO: "remote" is not a great name and we shouldn't just inherit it -%% from p11p-kit +%% from p11p-kit. Let's use "client" or "proxy_client". -module(p11p_remote). -behaviour(gen_server). %% API. -export([start_link/4]). --export([request/2, add_to_outbuf/2, stop/2]). +-export([request/2, stop/2]). -include("p11p_rpc.hrl"). @@ -30,24 +30,24 @@ port :: port(), replyto :: pid() | undefined, timer :: reference() | undefined, - token :: string(), % Name - outbuf = <<>> :: binary(), - msg :: p11rpc:msg() | undefined + token :: string(), % Token name. + msg :: p11rpc:msg() | undefined, + recv_count = 0 :: non_neg_integer(), + send_count = 0 :: non_neg_integer() }). %% API. --spec start_link(atom(), string(), string(), list()) -> {ok, pid()} | {error, term()}. +-spec start_link(atom(), string(), string(), list()) -> + {ok, pid()} | {error, term()}. start_link(ServName, TokName, ModPath, ModEnv) -> lager:info("~p: p11p_remote starting for ~s", [ServName, ModPath]), - gen_server:start_link({local, ServName}, ?MODULE, [TokName, ModPath, ModEnv], []). + gen_server:start_link({local, ServName}, ?MODULE, + [TokName, ModPath, ModEnv], []). --spec request(pid(), p11rpc_msg()) -> ok. +-spec request(pid(), p11rpc_msg()) -> {ok, non_neg_integer()}. request(Remote, Request) -> gen_server:call(Remote, {request, Request}). -add_to_outbuf(Remote, Data) -> - gen_server:call(Remote, {add_to_outbuf, Data}). - %% Use stop/1 instead of gen_server:stop/1 if you're uncertain whether %% Pid is alive or not. An example of when that can happen is when the %% manager receiving a server_event about a lost client. If the server @@ -70,37 +70,54 @@ init([TokName, ModPath, ModEnv]) -> lager:debug("~p: ~s: module: ~s, env: ~p", [self(), RemoteBinPath, ModPath, ModEnv]), {ok, #state{port = Port, token = TokName}}. -handle_call({add_to_outbuf, Data}, _From, State) -> - {reply, ok, do_add_to_outbuf(Data, State)}; -handle_call({request, Request}, {FromPid, _Tag}, #state{port = Port} = S) -> +handle_call({request, Request}, {FromPid, _Tag}, + #state{port = Port, send_count = Sent} = S) -> %%lager:debug("~p: sending request from ~p to remote ~p", [self(), FromPid, Port]), - State = do_send(do_add_to_outbuf(p11p_rpc:serialise(Request), S)), - {reply, ok, State#state{replyto = FromPid, timer = start_timer(Port)}}; -handle_call(Request, _From, State) -> - lager:debug("~p: Unhandled call: ~p~n", [self(), Request]), + D = p11p_rpc:serialise(Request), + Buf = case Sent of + 0 -> <<?RPC_VERSION:8, D/binary>>; + _ -> D + end, + ok = do_send(Port, Buf), + {reply, {ok, sizeBuf}, S#state{replyto = FromPid, timer = start_timer(Port), + send_count = Sent + 1}}; + +handle_call(Call, _From, State) -> + lager:debug("~p: Unhandled call: ~p~n", [self(), Call]), {reply, unhandled, State}. handle_cast({stop, Reason}, State) -> {stop, Reason, State}; + handle_cast(Cast, State) -> lager:debug("~p: unhandled cast: ~p~n", [self(), Cast]), {noreply, State}. -%% TODO: dedup code w/ p11p_server -handle_info({Port, {data, Data}}, #state{replyto = Pid} = State) +%% Receiving the very first response from remote since it was started. +handle_info({Port, {data, Data}}, State) when Port == State#state.port, State#state.msg == undefined -> - Version = hd(Data), % First octet is version. - {ok, _BytesAdded} = p11p_server:add_to_clientbuf(Pid, <<Version>>), - {noreply, handle_remote_data(State, p11p_rpc:new(), tl(Data))}; + case hd(Data) of % First octet is RPC protocol version. + ?RPC_VERSION -> + {noreply, handle_remote_data(State, p11p_rpc:new(), tl(Data))}; + BadVersion -> + lager:info("~p: ~p: invalid RPC version: ~p", [self(), Port, + BadVersion]), + {noreply, State} + end; + +%% Receiving more data from remote. handle_info({Port, {data, Data}}, #state{msg = Msg} = State) when Port == State#state.port -> {noreply, handle_remote_data(State, Msg, Data)}; + +%% Remote timed out. handle_info({timeout, Timer, Port}, #state{token = Tok, replyto = Server} = S) when Port == S#state.port, Timer == S#state.timer -> lager:info("~p: rpc request timed out, exiting", [self()]), p11p_remote_manager:server_event(timeout, [Tok, Server]), State = S#state{timer = undefined}, {stop, normal, State}; + handle_info(Info, State) -> lager:debug("~p: Unhandled info: ~p~n", [self(), Info]), {noreply, State}. @@ -114,11 +131,7 @@ code_change(_OldVersion, State, _Extra) -> {ok, State}. %% Private -do_add_to_outbuf(Data, #state{outbuf = OutBuf} = State) -> - %%lager:debug("~p: adding ~B octets to outbuf", [self(), size(Data)]), - State#state{outbuf = <<OutBuf/binary, Data/binary>>}. - -do_send(#state{port = Port, outbuf = Buf} = State) -> +do_send(Port, Buf) -> %%lager:debug("~p: sending ~B octets to remote", [self(), size(Buf)]), %% case rand:uniform(15) of @@ -128,17 +141,20 @@ do_send(#state{port = Port, outbuf = Buf} = State) -> %% port_command(Port, Buf) %% end, - port_command(Port, Buf), - State#state{outbuf = <<>>}. + true = port_command(Port, Buf), + ok. -handle_remote_data(#state{replyto = Pid, timer = Timer} = S, MsgIn, DataIn) -> +handle_remote_data(#state{replyto = Pid, timer = Timer, recv_count = Recv} = S, + MsgIn, DataIn) -> case p11p_rpc:parse(MsgIn, list_to_binary(DataIn)) of + {needmore, Msg} -> + S#state{msg = Msg}; {done, Msg} -> cancel_timer(Timer), {ok, _BytesSent} = p11p_server:reply(Pid, Msg), - S#state{msg = p11p_rpc:new(Msg#p11rpc_msg.buffer)}; - {needmore, Msg} -> - S#state{msg = Msg} + %% Saving potential data not consumed by parse/2 in new message. + S#state{msg = p11p_rpc:new(Msg#p11rpc_msg.buffer), + recv_count = Recv + 1} end. start_timer(Port) -> diff --git a/p11p-daemon/src/p11p_rpc.hrl b/p11p-daemon/src/p11p_rpc.hrl index 8ccb0d1..c511e20 100644 --- a/p11p-daemon/src/p11p_rpc.hrl +++ b/p11p-daemon/src/p11p_rpc.hrl @@ -1,6 +1,9 @@ %%% Copyright (c) 2019, Sunet. %%% See LICENSE for licensing information. +%% The only RPC version we support. +-define(RPC_VERSION, 0). + -record(p11rpc_msg, { call_code = -1 :: integer(), % Length is 4 opt_len = -1 :: integer(), % Length is 4 diff --git a/p11p-daemon/src/p11p_server.erl b/p11p-daemon/src/p11p_server.erl index ff1a8df..b3ffa5c 100644 --- a/p11p-daemon/src/p11p_server.erl +++ b/p11p-daemon/src/p11p_server.erl @@ -11,7 +11,7 @@ %% API. -export([start_link/1]). --export([add_to_clientbuf/2, reply/2]). +-export([reply/2]). %% Genserver callbacks. -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, @@ -23,7 +23,9 @@ remote :: pid() | undefined, socket :: gen_tcp:socket(), msg :: p11rpc_msg() | undefined, - clientbuf = <<>> :: binary() + recv_count = 0 :: non_neg_integer(), + send_count = 0 :: non_neg_integer() + %%clientbuf = <<>> :: binary() }). %% API. @@ -31,10 +33,6 @@ start_link(Args) -> gen_server:start_link(?MODULE, Args, []). --spec add_to_clientbuf(pid(), binary()) -> {ok, non_neg_integer()}. -add_to_clientbuf(Pid, Data) -> - gen_server:call(Pid, {add_to_clientbuf, Data}). - -spec reply(pid(), p11rpc_msg()) -> {ok, non_neg_integer()}. reply(Pid, Response) -> gen_server:call(Pid, {respond, Response}). @@ -46,19 +44,21 @@ init([Token, Socket]) -> gen_server:cast(self(), accept), % Invoke accept, returning a socket in state. {ok, #state{tokname = Token, socket = Socket}}. -handle_call({add_to_clientbuf, Data}, _, #state{clientbuf = B} = S) -> - Buf = <<B/binary, Data/binary>>, - {reply, {ok, size(Buf)}, S#state{clientbuf = Buf}}; -handle_call({respond, R}, _, #state{socket = Client, clientbuf = B} = S) -> - Data = p11p_rpc:serialise(R), - Buf = <<B/binary, Data/binary>>, - %%lager:debug("~p: sending ~B octets to client as response", [self(), size(Buf)]), - ok = gen_tcp:send(Client, Buf), % TODO: what about short writes? - {reply, {ok, size(Buf)}, S#state{clientbuf = <<>>}}; +handle_call({respond, R}, _, #state{socket = Sock, send_count = Sent} = S) -> + D = p11p_rpc:serialise(R), + Buf = case Sent of + 0 -> <<?RPC_VERSION:8, D/binary>>; + _ -> D + end, + %%lager:debug("~p: sending ~B octets as response", [self(), size(Buf)]), + ok = gen_tcp:send(Sock, Buf), % TODO: what about short writes? + {reply, {ok, size(Buf)}, S#state{send_count = Sent + 1}}; + handle_call(Call, _, S) -> lager:debug("~p: Unhandled call: ~p~n", [self(), Call]), {reply, unhandled, S}. +%% Wait for new connection. handle_cast(accept, State = #state{tokname = TokName, socket = ListenSocket}) -> %% Blocking until client connects or timeout fires. %% Without a timeout our supervisor cannot terminate us. @@ -76,24 +76,38 @@ handle_cast(accept, State = #state{tokname = TokName, socket = ListenSocket}) -> lager:debug("~p: listening socket closed", [self()]), {stop, normal, State} end; + handle_cast(Cast, State) -> lager:debug("~p: Unhandled cast: ~p~n", [self(), Cast]), {noreply, State}. -handle_info({tcp, _Port, DataIn}, #state{tokname = TokName} = S) +%% First packet from P11 client. +handle_info({tcp, Port, DataIn}, #state{tokname = TokName} = S) when S#state.remote == undefined -> %%lager:debug("~p: received ~B octets from client on socket ~p, from new client", [self(), size(Data), Port]), - <<Version:8, Data/binary>> = DataIn, - Remote = p11p_remote_manager:remote_for_token(TokName), - p11p_remote:add_to_outbuf(Remote, <<Version>>), - State = S#state{remote = Remote}, - {noreply, handle_client_data(State, p11p_rpc:new(), Data)}; + <<RPCVersion:8, Data/binary>> = DataIn, + case RPCVersion of + ?RPC_VERSION -> + {noreply, + p11_client_data( + S#state{remote = p11p_remote_manager:remote_for_token(TokName)}, + p11p_rpc:new(), + Data)}; + BadVersion -> + lager:info("~p: ~p: invalid RPC version: ~p", [self(), Port, + BadVersion]), + {stop, bad_proto, S} + end; + +%% Subsequent packages from P11 client. handle_info({tcp, _Port, DataIn}, #state{msg = Msg} = S) -> %%lager:debug("~p: received ~B octets from client on socket ~p, with ~B octets already in buffer", [self(), size(Data), Port, size(Msg#p11rpc_msg.buffer)]), - {noreply, handle_client_data(S, Msg, DataIn)}; + {noreply, p11_client_data(S, Msg, DataIn)}; + handle_info({tcp_closed, Port}, S) -> lager:debug("~p: socket ~p closed", [self(), Port]), {stop, normal, S}; + handle_info(Info, S) -> lager:debug("~p: Unhandled info: ~p~n", [self(), Info]), {noreply, S}. @@ -108,11 +122,13 @@ code_change(_OldVersion, State, _Extra) -> {ok, State}. %% Private functions. -handle_client_data(#state{remote = Remote} = S, MsgIn, DataIn) -> +p11_client_data(#state{remote = Remote, recv_count = Recv} = S, MsgIn, + DataIn) -> case p11p_rpc:parse(MsgIn, DataIn) of - {done, Msg} -> - ok = p11p_remote:request(Remote, Msg), - S#state{msg = p11p_rpc:new(Msg#p11rpc_msg.buffer)}; {needmore, Msg} -> - S#state{msg = Msg} + S#state{msg = Msg}; + {done, Msg} -> + {ok, _BytesSent} = p11p_remote:request(Remote, Msg), + S#state{msg = p11p_rpc:new(Msg#p11rpc_msg.buffer), + recv_count = Recv + 1} end. |