%%% Copyright (c) 2016, NORDUnet A/S.
%%% See LICENSE for licensing information.

-module(dnssecport).
-behaviour(gen_server).
-export([start_link/0, stop/0]).
-export([validate/1]).
%% gen_server callbacks.
-export([init/1, handle_call/3, terminate/2, handle_cast/2, handle_info/2,
         code_change/3]).

-include_lib("eunit/include/eunit.hrl").

%% Start the server, registered locally under the module name. init/1 is
%% handed a one-element list holding the path of the dnssecport executable
%% in the catlfish application's priv directory.
start_link() ->
    PrivDir = code:priv_dir(catlfish),
    InitArg = [PrivDir ++ "/dnssecport"],
    gen_server:start_link({local, ?MODULE}, ?MODULE, InitArg, []).

%% Ask the running server to close its port and terminate.
stop() ->
    Request = stop,
    gen_server:call(?MODULE, Request).

%% Hand Data to the dnssecport program for validation; the reply is built
%% by handle_call/3.
validate(Data) ->
    Request = {validate, Data},
    gen_server:call(?MODULE, Request).

%% Server state: the port running the external dnssecport program.
%% NOTE(review): handle_call/3 and stop_port/1 also store `undefined' here
%% once the port is gone, which the port() type annotation does not cover.
-record(state, {port :: port()}).

%% File name of the trust anchors, read from the catlfish application
%% environment key trust_anchors_file. Falls back to [] (the empty string)
%% when the key is not set.
-spec trust_anchors() -> string().
trust_anchors() ->
    application:get_env(catlfish, trust_anchors_file, []).

%% gen_server init callback: open the dnssecport program as a port. Its
%% only command-line argument is the trust anchors file name. Program is
%% the init argument given by start_link/0.
init(Program) ->
    lager:debug("starting dnssec service"),
    TrustAnchors = trust_anchors(),
    {ok, #state{port = create_port(Program, [TrustAnchors])}}.

%% Split a raw port answer into its leading 16-bit big-endian unsigned
%% status code and the RR set that follows it. The RR set is decoded with
%% dns:decode_rrset/1. Crashes with badmatch if fewer than 2 bytes arrive.
decode_response(Response) ->
    <<Code:16, Rest/binary>> = Response,
    Decoded = dns:decode_rrset(Rest),
    {ok, Code, Decoded}.

%% @doc gen_server call handler.
%%
%% stop: closes the port via stop_port/1 and terminates the server with the
%% reason stop_port/1 reports. It replies `ok' first; without a reply, the
%% caller of stop/0 would get an exit instead of a return value.
%%
%% {validate, Data}: sends Data to the dnssecport program and waits for its
%% answer. Possible replies:
%%   - {ok, [EncodedDS, EncodedChain]} when the program reports status 400.
%%   - {error, Status} for any other status.
%%   - {error, noport} when no port is open.
%%   - {error, portexit} or {error, timeout} if the program exits or does not
%%     answer in time. The server then stops.
handle_call(stop, _From, State) ->
    lager:debug("dnssec stop request received"),
    {stop, Reason, NewState} = stop_port(State),
    {stop, Reason, ok, NewState};
handle_call({validate, _Data}, _From, #state{port = undefined} = State) ->
    {reply, {error, noport}, State};
handle_call({validate, Data}, _From, #state{port = Port} = State)
  when is_port(Port) ->
    Port ! {self(), {command, Data}},
    receive
        {Port, {data, Response}} ->
            %% The port is not opened in binary mode, so data arrives as a list.
            {reply, validation_result(list_to_binary(Response)), State};
        {Port, {exit_status, ExitStatus}} ->
            lager:error("dnssec port ~p exiting with status ~p",
                        [Port, ExitStatus]),
            {stop, portexit, {error, portexit}, State#state{port = undefined}}
    after
        %% Shorter than the 5 s default timeout of gen_server:call/2 (used by
        %% validate/1), so the error reply reaches the caller.
        3000 ->
            lager:error("dnssec port timeout"),
            {stop, timeout, {error, timeout}, State}
    end.

%% Map a raw dnssecport answer to the reply for a validate request.
%% Status 400 is treated as success. The first RR is the DS record and the
%% rest is the chain. The chain's first RR is used as the RRSIG when
%% canonicalizing the DS.
validation_result(Response) ->
    case decode_response(Response) of
        {ok, 400, [DS | Chain]} ->
            [RRSIG | _] = Chain,
            {ok, [dns:encode_rr(dns:canonicalize_dsrr(DS, RRSIG)),
                  dns:encode_rrset(Chain)]};
        {ok, Status, _} ->
            lager:debug("DNSSEC validation failed with ~p", [Status]),
            {error, Status}
    end.

%% @doc Handle out-of-band messages.
%% This process owns exactly one port, opened in init/1. An exit_status
%% message therefore means the dnssecport program died outside a validate
%% request. Stop right away rather than keep a dead port and only notice on
%% the next request, after its timeout. All other messages are ignored.
handle_info({Port, {exit_status, ExitStatus}}, State) when is_port(Port) ->
    lager:error("dnssec port ~p exiting with status ~p", [Port, ExitStatus]),
    {stop, portexit, State};
handle_info(_Info, State) ->
    {noreply, State}.

%% Hot code upgrade callback: the state is carried over unchanged.
code_change(_PreviousVsn, Unchanged, _Extra) ->
    {ok, Unchanged}.

%% This server defines no cast requests; every cast is ignored.
handle_cast(_Cast, Unchanged) ->
    {noreply, Unchanged}.

%% gen_server terminate callback: only logs the termination reason; it
%% does not close the port itself.
terminate(Why, _State) ->
    lager:info("dnssec port terminating: ~p", [Why]),
    ok.

%%%%%%%%%%%%%%%%%%%%
%% Start Program directly (spawn_executable, no shell) with Args as its
%% command-line arguments. Messages are framed with a 4-byte length header.
%% exit_status makes the port send {Port, {exit_status, N}} when the
%% program exits.
create_port(Program, Args) ->
    PortOpts = [{args, Args}, exit_status, {packet, 4}],
    open_port({spawn_executable, Program}, PortOpts).

%% Ask the port held in State to close and wait up to 2 s for it. Returns a
%% gen_server {stop, Reason, State} tuple:
%%   - closed: the port confirmed the close.
%%   - unknown: some other message from the port arrived first.
%%   - timeout: nothing arrived. The port is then kept in the state.
stop_port(State) ->
    Port = State#state.port,
    Gone = State#state{port = undefined},
    Port ! {self(), close},
    receive
        {Port, closed} ->
            {stop, closed, Gone};
        {Port, Other} ->
            lager:debug("message received from dying port: ~p", [Other]),
            {stop, unknown, Gone}
    after 2000 ->
        lager:error("dnssec port ~p refuses to die", [Port]),
        {stop, timeout, State}
    end.

%%%%%%%%%%%%%%%%%%%%
%% Unit tests.
%% Test fixtures, relative to the directory the tests are run from.
%% Trust anchors file given to the test port.
-define(TA_FILE, "test/testdata/dnssec/trust_anchors").
%% Sample submission files read by read_submission_from_file/1.
-define(REQ1_FILE, "test/testdata/dnssec/req-basic").
-define(REQ2_FILE, "test/testdata/dnssec/req-lowttl").

%% Test helper: open the dnssecport program from the local priv directory,
%% using the test trust anchors.
start_test_port() ->
    Executable = "priv/dnssecport",
    create_port(Executable, [?TA_FILE]).

%% Test helper: close Port through stop_port/1 and assert a clean close.
stop_test_port(Port) ->
    Outcome = stop_port(#state{port = Port}),
    {stop, closed, _} = Outcome,
    ok.

%% Test helper: read Filename and decode its contents as an RR set.
%% Crashes if the file cannot be read.
read_submission_from_file(Filename) ->
    {ok, Contents} = file:read_file(Filename),
    dns:decode_rrset(Contents).

%% Round-trip test: decoding ?REQ1_FILE and re-encoding the result must
%% reproduce the file byte for byte.
read_dec_enc_test_() ->
    Decoded = read_submission_from_file(?REQ1_FILE),
    {ok, Original} = file:read_file(?REQ1_FILE),
    [?_test(?assertEqual(Original, dns:encode_rrset(Decoded)))].

%% TODO: These tests only cover successful validation; add cases for
%% validation errors and port failures.
%% End-to-end test: run handle_call({validate, _}, ...) against the real
%% dnssecport program with the test trust anchors and two sample
%% submissions. Both must validate. For the req-lowttl submission, the
%% returned DS record's fifth field (presumably the TTL) must be 3600.
full_test_() ->
    {setup,
     fun start_test_port/0,
     fun stop_test_port/1,
     fun(Port) ->
             R1 = handle_call({validate, read_submission_from_file(?REQ1_FILE)},
                              self(), #state{port = Port}),
             R2 = handle_call({validate, read_submission_from_file(?REQ2_FILE)},
                              self(), #state{port = Port}),
             [?_assertMatch({reply, {ok, _}, _}, R1),
              ?_assertMatch({reply, {ok, _}, _}, R2),
              %% Destructure R2 inside the test, not in the instantiator, so a
              %% failed validation is reported as a test failure instead of
              %% aborting the whole fixture.
              ?_test(begin
                         {reply, {ok, [DSBin | _ChainBin]}, _} = R2,
                         {DS, <<>>} = dns:decode_rr(DSBin),
                         ?assertMatch({rr, _Name, _Type, _Class, 3600, _RData}, DS)
                     end)]
     end}.

%% start_test_port(TestType) ->
%%     Port = create_port("priv/dnssecport", ["--testmode", atom_to_list(TestType)]),
%%     ?debugFmt("Port: ~p", [Port]),
%%     Port.
%% stop_test_port(Port) ->
%%     {stop, closed, _State} = stop_port(#state{port = Port}),
%%     ok.

%% err_test_() ->
%%     {setup,
%%      fun() -> start_test_port(err) end,
%%      fun(Port) -> stop_test_port(Port) end,
%%      fun(Port)  ->
%%              R = handle_call({validate, [<<"invalid-DS">>, []]},
%%                              self(), #state{port = Port}),
%%              [
%%               ?_assertMatch({reply, {error, "err"}, _State}, R)
%%              ]
%%      end}.

%% ok_test_() ->
%%     {setup,
%%      fun() -> start_test_port(ok) end,
%%      fun(Port) -> stop_test_port(Port) end,
%%      fun(Port)  ->
%%              R = handle_call({validate, [<<"invalid-DS">>, []]},
%%                              self(), #state{port = Port}),
%%              [
%%               ?_assertMatch({reply, ok, _State}, R)
%%              ]
%%      end}.