summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--p11p-daemon/src/p11p_remote.erl84
-rw-r--r--p11p-daemon/src/p11p_rpc.hrl3
-rw-r--r--p11p-daemon/src/p11p_server.erl70
3 files changed, 96 insertions, 61 deletions
diff --git a/p11p-daemon/src/p11p_remote.erl b/p11p-daemon/src/p11p_remote.erl
index d0f9184..b27b333 100644
--- a/p11p-daemon/src/p11p_remote.erl
+++ b/p11p-daemon/src/p11p_remote.erl
@@ -10,14 +10,14 @@
%% times out, inform the remote manager (our parent).
%% TODO: "remote" is not a great name and we shouldn't just inherit it
-%% from p11p-kit
+%% from p11p-kit. Let's use "client" or "proxy_client".
-module(p11p_remote).
-behaviour(gen_server).
%% API.
-export([start_link/4]).
--export([request/2, add_to_outbuf/2, stop/2]).
+-export([request/2, stop/2]).
-include("p11p_rpc.hrl").
@@ -30,24 +30,24 @@
port :: port(),
replyto :: pid() | undefined,
timer :: reference() | undefined,
- token :: string(), % Name
- outbuf = <<>> :: binary(),
- msg :: p11rpc:msg() | undefined
+ token :: string(), % Token name.
+ msg :: p11rpc:msg() | undefined,
+ recv_count = 0 :: non_neg_integer(),
+ send_count = 0 :: non_neg_integer()
}).
%% API.
--spec start_link(atom(), string(), string(), list()) -> {ok, pid()} | {error, term()}.
+-spec start_link(atom(), string(), string(), list()) ->
+ {ok, pid()} | {error, term()}.
start_link(ServName, TokName, ModPath, ModEnv) ->
lager:info("~p: p11p_remote starting for ~s", [ServName, ModPath]),
- gen_server:start_link({local, ServName}, ?MODULE, [TokName, ModPath, ModEnv], []).
+ gen_server:start_link({local, ServName}, ?MODULE,
+ [TokName, ModPath, ModEnv], []).
--spec request(pid(), p11rpc_msg()) -> ok.
+-spec request(pid(), p11rpc_msg()) -> {ok, non_neg_integer()}.
request(Remote, Request) ->
gen_server:call(Remote, {request, Request}).
-add_to_outbuf(Remote, Data) ->
- gen_server:call(Remote, {add_to_outbuf, Data}).
-
%% Use stop/1 instead of gen_server:stop/1 if you're uncertain whether
%% Pid is alive or not. An example of when that can happen is when the
%% manager receiving a server_event about a lost client. If the server
@@ -70,37 +70,54 @@ init([TokName, ModPath, ModEnv]) ->
lager:debug("~p: ~s: module: ~s, env: ~p", [self(), RemoteBinPath, ModPath, ModEnv]),
{ok, #state{port = Port, token = TokName}}.
-handle_call({add_to_outbuf, Data}, _From, State) ->
- {reply, ok, do_add_to_outbuf(Data, State)};
-handle_call({request, Request}, {FromPid, _Tag}, #state{port = Port} = S) ->
+handle_call({request, Request}, {FromPid, _Tag},
+ #state{port = Port, send_count = Sent} = S) ->
%%lager:debug("~p: sending request from ~p to remote ~p", [self(), FromPid, Port]),
- State = do_send(do_add_to_outbuf(p11p_rpc:serialise(Request), S)),
- {reply, ok, State#state{replyto = FromPid, timer = start_timer(Port)}};
-handle_call(Request, _From, State) ->
- lager:debug("~p: Unhandled call: ~p~n", [self(), Request]),
+ D = p11p_rpc:serialise(Request),
+ Buf = case Sent of
+ 0 -> <<?RPC_VERSION:8, D/binary>>;
+ _ -> D
+ end,
+ ok = do_send(Port, Buf),
+ {reply, {ok, sizeBuf}, S#state{replyto = FromPid, timer = start_timer(Port),
+ send_count = Sent + 1}};
+
+handle_call(Call, _From, State) ->
+ lager:debug("~p: Unhandled call: ~p~n", [self(), Call]),
{reply, unhandled, State}.
handle_cast({stop, Reason}, State) ->
{stop, Reason, State};
+
handle_cast(Cast, State) ->
lager:debug("~p: unhandled cast: ~p~n", [self(), Cast]),
{noreply, State}.
-%% TODO: dedup code w/ p11p_server
-handle_info({Port, {data, Data}}, #state{replyto = Pid} = State)
+%% Receiving the very first response from remote since it was started.
+handle_info({Port, {data, Data}}, State)
when Port == State#state.port, State#state.msg == undefined ->
- Version = hd(Data), % First octet is version.
- {ok, _BytesAdded} = p11p_server:add_to_clientbuf(Pid, <<Version>>),
- {noreply, handle_remote_data(State, p11p_rpc:new(), tl(Data))};
+ case hd(Data) of % First octet is RPC protocol version.
+ ?RPC_VERSION ->
+ {noreply, handle_remote_data(State, p11p_rpc:new(), tl(Data))};
+ BadVersion ->
+ lager:info("~p: ~p: invalid RPC version: ~p", [self(), Port,
+ BadVersion]),
+ {noreply, State}
+ end;
+
+%% Receiving more data from remote.
handle_info({Port, {data, Data}}, #state{msg = Msg} = State)
when Port == State#state.port ->
{noreply, handle_remote_data(State, Msg, Data)};
+
+%% Remote timed out.
handle_info({timeout, Timer, Port}, #state{token = Tok, replyto = Server} = S)
when Port == S#state.port, Timer == S#state.timer ->
lager:info("~p: rpc request timed out, exiting", [self()]),
p11p_remote_manager:server_event(timeout, [Tok, Server]),
State = S#state{timer = undefined},
{stop, normal, State};
+
handle_info(Info, State) ->
lager:debug("~p: Unhandled info: ~p~n", [self(), Info]),
{noreply, State}.
@@ -114,11 +131,7 @@ code_change(_OldVersion, State, _Extra) ->
{ok, State}.
%% Private
-do_add_to_outbuf(Data, #state{outbuf = OutBuf} = State) ->
- %%lager:debug("~p: adding ~B octets to outbuf", [self(), size(Data)]),
- State#state{outbuf = <<OutBuf/binary, Data/binary>>}.
-
-do_send(#state{port = Port, outbuf = Buf} = State) ->
+do_send(Port, Buf) ->
%%lager:debug("~p: sending ~B octets to remote", [self(), size(Buf)]),
%% case rand:uniform(15) of
@@ -128,17 +141,20 @@ do_send(#state{port = Port, outbuf = Buf} = State) ->
%% port_command(Port, Buf)
%% end,
- port_command(Port, Buf),
- State#state{outbuf = <<>>}.
+ true = port_command(Port, Buf),
+ ok.
-handle_remote_data(#state{replyto = Pid, timer = Timer} = S, MsgIn, DataIn) ->
+handle_remote_data(#state{replyto = Pid, timer = Timer, recv_count = Recv} = S,
+ MsgIn, DataIn) ->
case p11p_rpc:parse(MsgIn, list_to_binary(DataIn)) of
+ {needmore, Msg} ->
+ S#state{msg = Msg};
{done, Msg} ->
cancel_timer(Timer),
{ok, _BytesSent} = p11p_server:reply(Pid, Msg),
- S#state{msg = p11p_rpc:new(Msg#p11rpc_msg.buffer)};
- {needmore, Msg} ->
- S#state{msg = Msg}
+ %% Saving potential data not consumed by parse/2 in new message.
+ S#state{msg = p11p_rpc:new(Msg#p11rpc_msg.buffer),
+ recv_count = Recv + 1}
end.
start_timer(Port) ->
diff --git a/p11p-daemon/src/p11p_rpc.hrl b/p11p-daemon/src/p11p_rpc.hrl
index 8ccb0d1..c511e20 100644
--- a/p11p-daemon/src/p11p_rpc.hrl
+++ b/p11p-daemon/src/p11p_rpc.hrl
@@ -1,6 +1,9 @@
%%% Copyright (c) 2019, Sunet.
%%% See LICENSE for licensing information.
+%% The only RPC version we support.
+-define(RPC_VERSION, 0).
+
-record(p11rpc_msg, {
call_code = -1 :: integer(), % Length is 4
opt_len = -1 :: integer(), % Length is 4
diff --git a/p11p-daemon/src/p11p_server.erl b/p11p-daemon/src/p11p_server.erl
index ff1a8df..b3ffa5c 100644
--- a/p11p-daemon/src/p11p_server.erl
+++ b/p11p-daemon/src/p11p_server.erl
@@ -11,7 +11,7 @@
%% API.
-export([start_link/1]).
--export([add_to_clientbuf/2, reply/2]).
+-export([reply/2]).
%% Genserver callbacks.
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2,
@@ -23,7 +23,9 @@
remote :: pid() | undefined,
socket :: gen_tcp:socket(),
msg :: p11rpc_msg() | undefined,
- clientbuf = <<>> :: binary()
+ recv_count = 0 :: non_neg_integer(),
+ send_count = 0 :: non_neg_integer()
+ %%clientbuf = <<>> :: binary()
}).
%% API.
@@ -31,10 +33,6 @@
start_link(Args) ->
gen_server:start_link(?MODULE, Args, []).
--spec add_to_clientbuf(pid(), binary()) -> {ok, non_neg_integer()}.
-add_to_clientbuf(Pid, Data) ->
- gen_server:call(Pid, {add_to_clientbuf, Data}).
-
-spec reply(pid(), p11rpc_msg()) -> {ok, non_neg_integer()}.
reply(Pid, Response) ->
gen_server:call(Pid, {respond, Response}).
@@ -46,19 +44,21 @@ init([Token, Socket]) ->
gen_server:cast(self(), accept), % Invoke accept, returning a socket in state.
{ok, #state{tokname = Token, socket = Socket}}.
-handle_call({add_to_clientbuf, Data}, _, #state{clientbuf = B} = S) ->
- Buf = <<B/binary, Data/binary>>,
- {reply, {ok, size(Buf)}, S#state{clientbuf = Buf}};
-handle_call({respond, R}, _, #state{socket = Client, clientbuf = B} = S) ->
- Data = p11p_rpc:serialise(R),
- Buf = <<B/binary, Data/binary>>,
- %%lager:debug("~p: sending ~B octets to client as response", [self(), size(Buf)]),
- ok = gen_tcp:send(Client, Buf), % TODO: what about short writes?
- {reply, {ok, size(Buf)}, S#state{clientbuf = <<>>}};
+handle_call({respond, R}, _, #state{socket = Sock, send_count = Sent} = S) ->
+ D = p11p_rpc:serialise(R),
+ Buf = case Sent of
+ 0 -> <<?RPC_VERSION:8, D/binary>>;
+ _ -> D
+ end,
+ %%lager:debug("~p: sending ~B octets as response", [self(), size(Buf)]),
+ ok = gen_tcp:send(Sock, Buf), % TODO: what about short writes?
+ {reply, {ok, size(Buf)}, S#state{send_count = Sent + 1}};
+
handle_call(Call, _, S) ->
lager:debug("~p: Unhandled call: ~p~n", [self(), Call]),
{reply, unhandled, S}.
+%% Wait for new connection.
handle_cast(accept, State = #state{tokname = TokName, socket = ListenSocket}) ->
%% Blocking until client connects or timeout fires.
%% Without a timeout our supervisor cannot terminate us.
@@ -76,24 +76,38 @@ handle_cast(accept, State = #state{tokname = TokName, socket = ListenSocket}) ->
lager:debug("~p: listening socket closed", [self()]),
{stop, normal, State}
end;
+
handle_cast(Cast, State) ->
lager:debug("~p: Unhandled cast: ~p~n", [self(), Cast]),
{noreply, State}.
-handle_info({tcp, _Port, DataIn}, #state{tokname = TokName} = S)
+%% First packet from P11 client.
+handle_info({tcp, Port, DataIn}, #state{tokname = TokName} = S)
when S#state.remote == undefined ->
%%lager:debug("~p: received ~B octets from client on socket ~p, from new client", [self(), size(Data), Port]),
- <<Version:8, Data/binary>> = DataIn,
- Remote = p11p_remote_manager:remote_for_token(TokName),
- p11p_remote:add_to_outbuf(Remote, <<Version>>),
- State = S#state{remote = Remote},
- {noreply, handle_client_data(State, p11p_rpc:new(), Data)};
+ <<RPCVersion:8, Data/binary>> = DataIn,
+ case RPCVersion of
+ ?RPC_VERSION ->
+ {noreply,
+ p11_client_data(
+ S#state{remote = p11p_remote_manager:remote_for_token(TokName)},
+ p11p_rpc:new(),
+ Data)};
+ BadVersion ->
+ lager:info("~p: ~p: invalid RPC version: ~p", [self(), Port,
+ BadVersion]),
+ {stop, bad_proto, S}
+ end;
+
+%% Subsequent packages from P11 client.
handle_info({tcp, _Port, DataIn}, #state{msg = Msg} = S) ->
%%lager:debug("~p: received ~B octets from client on socket ~p, with ~B octets already in buffer", [self(), size(Data), Port, size(Msg#p11rpc_msg.buffer)]),
- {noreply, handle_client_data(S, Msg, DataIn)};
+ {noreply, p11_client_data(S, Msg, DataIn)};
+
handle_info({tcp_closed, Port}, S) ->
lager:debug("~p: socket ~p closed", [self(), Port]),
{stop, normal, S};
+
handle_info(Info, S) ->
lager:debug("~p: Unhandled info: ~p~n", [self(), Info]),
{noreply, S}.
@@ -108,11 +122,13 @@ code_change(_OldVersion, State, _Extra) ->
{ok, State}.
%% Private functions.
-handle_client_data(#state{remote = Remote} = S, MsgIn, DataIn) ->
+p11_client_data(#state{remote = Remote, recv_count = Recv} = S, MsgIn,
+ DataIn) ->
case p11p_rpc:parse(MsgIn, DataIn) of
- {done, Msg} ->
- ok = p11p_remote:request(Remote, Msg),
- S#state{msg = p11p_rpc:new(Msg#p11rpc_msg.buffer)};
{needmore, Msg} ->
- S#state{msg = Msg}
+ S#state{msg = Msg};
+ {done, Msg} ->
+ {ok, _BytesSent} = p11p_remote:request(Remote, Msg),
+ S#state{msg = p11p_rpc:new(Msg#p11rpc_msg.buffer),
+ recv_count = Recv + 1}
end.