%%% Copyright (c) 2014-2015, NORDUnet A/S.
%%% See LICENSE for licensing information.

-module(catlfish).
-export([add_chain/3, entries/2, entry_and_proof/2]).
-export([known_roots/0, update_known_roots/0]).
-export([init_cache_table/0]).
-export([entryhash_from_entry/1, verify_entry/1, verify_entry/2]).
-include_lib("eunit/include/eunit.hrl").

%% Version number written as the first byte of the data signed for an
%% SCT (see calc_sct/1) and reported as `sct_version' by add_chain/3.
-define(PROTOCOL_VERSION, 0).

%% Symbolic forms of the RFC 6962 wire enums.  The trailing comment on
%% each type gives the integer width it is serialised with.  Not every
%% alternative has a serialiser clause in this file.
-type signature_type() :: certificate_timestamp | tree_hash | test. % uint8
-type entry_type() :: x509_entry | precert_entry | test. % uint16
-type leaf_type() :: timestamped_entry | test.           % uint8
-type leaf_version() :: v1 | v2.                         % uint8

%% MerkleTreeLeaf (RFC 6962 Section 3.4).
%% NOTE: record field order defines the runtime tuple layout; do not
%% reorder fields.
-record(mtl, {leaf_version :: leaf_version(),
              leaf_type :: leaf_type(),
              entry :: timestamped_entry()}).
-type mtl() :: #mtl{}.

%% TimestampedEntry (RFC 6962 Section 3.4).  `timestamp' is serialised
%% as a 64-bit integer; `extensions' holds the raw CtExtensions bytes
%% and defaults to the empty binary.
-record(timestamped_entry, {timestamp :: integer(),
                            entry_type :: entry_type(),
                            signed_entry :: signed_x509_entry() |
                                            signed_precert_entry(),
                            extensions = <<>> :: binary()}).
-type timestamped_entry() :: #timestamped_entry{}.

%% Signed payload of an x509_entry: the leaf certificate bytes
%% (ASN.1Cert, RFC 6962 Section 3.1).
-record(signed_x509_entry, {asn1_cert :: binary()}).
-type signed_x509_entry() :: #signed_x509_entry{}.
%% Signed payload of a precert_entry (PreCert, RFC 6962 Section 3.2).
%% serialise/1 only accepts a 32-byte issuer_key_hash.
-record(signed_precert_entry, {issuer_key_hash :: binary(),
                               tbs_certificate :: binary()}).
-type signed_precert_entry() :: #signed_precert_entry{}.

%% @doc Serialise one of the RFC 6962 leaf structures to its TLS wire form:
%%   #mtl{}                  - MerkleTreeLeaf (Section 3.4)
%%   #timestamped_entry{}    - TimestampedEntry (Section 3.4)
%%   #signed_x509_entry{}    - ASN.1Cert as a 24-bit-length vector
%%                             (Section 3.1)
%%   #signed_precert_entry{} - PreCert: the 32-byte issuer_key_hash
%%                             followed by the TBSCertificate as a
%%                             24-bit-length vector (Section 3.2)
%% Any other term, including a PreCert whose issuer_key_hash is not a
%% 32-byte binary, raises function_clause.
-spec serialise(mtl() | timestamped_entry() |
                signed_x509_entry() | signed_precert_entry()) -> binary().
serialise(#mtl{leaf_version = LeafVersion,
               leaf_type = LeafType,
               entry = TimestampedEntry}) ->
    list_to_binary(
      [serialise_leaf_version(LeafVersion),
       serialise_leaf_type(LeafType),
       serialise(TimestampedEntry)]);
serialise(#timestamped_entry{timestamp = Timestamp,
                             entry_type = EntryType,
                             signed_entry = SignedEntry,
                             extensions = Extensions}) ->
    %% Timestamp is a uint64; CtExtensions is a 16-bit-length vector.
    list_to_binary(
      [<<Timestamp:64>>,
       serialise_entry_type(EntryType),
       serialise(SignedEntry),
       encode_tls_vector(Extensions, 2)]);
serialise(#signed_x509_entry{asn1_cert = Cert}) ->
    encode_tls_vector(Cert, 3);
serialise(#signed_precert_entry{
             issuer_key_hash = IssuerKeyHash,
             tbs_certificate = TBSCertificate})
  when is_binary(IssuerKeyHash), byte_size(IssuerKeyHash) =:= 32 ->
    list_to_binary(
      [IssuerKeyHash,
       encode_tls_vector(TBSCertificate, 3)]).

%% LeafVersion <-> single wire byte: v1 is 0, v2 is 1.
serialise_leaf_version(v1) -> <<0>>;
serialise_leaf_version(v2) -> <<1>>.
deserialise_leaf_version(<<0>>) -> v1;
deserialise_leaf_version(<<1>>) -> v2.

%% MerkleLeafType <-> single wire byte; only timestamped_entry (0) exists.
serialise_leaf_type(timestamped_entry) -> <<0>>.
deserialise_leaf_type(<<0>>) -> timestamped_entry.

%% LogEntryType <-> two big-endian wire bytes:
%% x509_entry is 0, precert_entry is 1.
serialise_entry_type(x509_entry) -> <<0, 0>>;
serialise_entry_type(precert_entry) -> <<0, 1>>.
deserialise_entry_type(<<0, 0>>) -> x509_entry;
deserialise_entry_type(<<0, 1>>) -> precert_entry.

%% Serialise an RFC 6962 SignatureType (Section 3.2) as a uint8:
%% certificate_timestamp(0) and tree_hash(1).  The spec admits any
%% signature_type(), so both RFC-defined values are handled.
-spec serialise_signature_type(signature_type()) -> binary().
serialise_signature_type(certificate_timestamp) ->
    <<0:8>>;
serialise_signature_type(tree_hash) ->
    <<1:8>>.

%% Compute the serialised SCT signature for TimestampedEntry.  The signed
%% input is the protocol version byte, the certificate_timestamp
%% signature type, and the serialised entry; signing and serialising the
%% signature are delegated to plop.
calc_sct(TimestampedEntry) ->
    SignedInput = [<<?PROTOCOL_VERSION:8>>,
                   serialise_signature_type(certificate_timestamp),
                   serialise(TimestampedEntry)],
    Signature = plop:spt(list_to_binary(SignedInput)),
    plop:serialise(Signature).

%% Return the SCT signature for TimestampedEntry.  When the catlfish
%% `sctcache_root_path' env var is set, the result is looked up in (and
%% on a miss written to) a perm store under that path keyed by Hash;
%% otherwise it is recomputed every time.
get_sct(Hash, TimestampedEntry) ->
    case application:get_env(catlfish, sctcache_root_path) of
        {ok, RootPath} ->
            sct_from_cache(RootPath, Hash, TimestampedEntry);
        _ ->
            calc_sct(TimestampedEntry)
    end.

%% Read the cached SCT for Hash, computing and storing it on a miss.
sct_from_cache(RootPath, Hash, TimestampedEntry) ->
    case perm:readfile(RootPath, Hash) of
        noentry ->
            FreshSCT = calc_sct(TimestampedEntry),
            ok = perm:ensurefile_nosync(RootPath, Hash, FreshSCT),
            FreshSCT;
        CachedSCT when is_binary(CachedSCT) ->
            CachedSCT
    end.

%% Create a new log entry for the chain and hand it to plop:add/3.
%% Type `normal' stores an x509_entry whose extra data is CertChain;
%% `precert' stores a precert_entry whose extra data also carries the
%% original LeafCert in front of the chain.
%% The stored entry is: MTL text as a 32-bit-length vector, followed by
%% a 32-bit-length vector wrapping a 24-bit-length vector of the
%% concatenated 24-bit-length certificate vectors.
%% Returns {TimestampedEntry, MerkleLeafHash}.
add_to_db(Type, LeafCert, CertChain, EntryHash) ->
    {EntryType, ExtraData} =
        case Type of
            normal -> {x509_entry, CertChain};
            precert -> {precert_entry, [LeafCert | CertChain]}
        end,
    TSE = timestamped_entry(plop:generate_timestamp(), EntryType,
                            LeafCert, CertChain),
    MTLText = serialise(#mtl{leaf_version = v1,
                             leaf_type = timestamped_entry,
                             entry = TSE}),
    MTLHash = ht:leaf_hash(MTLText),
    CertVectors = list_to_binary([encode_tls_vector(Cert, 3)
                                  || Cert <- ExtraData]),
    ExtraField = encode_tls_vector(encode_tls_vector(CertVectors, 3), 4),
    LogEntry = list_to_binary([encode_tls_vector(MTLText, 4), ExtraField]),
    ok = plop:add(LogEntry, MTLHash, EntryHash),
    {TSE, MTLHash}.

%% Thin wrapper around ratelimit:get_token/1 for the given token kind.
get_ratelimit_token(Kind) -> ratelimit:get_token(Kind).

%% Submit a chain (or find it if already logged, keyed by the SHA-256 of
%% [LeafCert | CertChain]) and return the SCT fields as a
%% {[{Key, Value}]} term.  New submissions consume an `add_chain' rate
%% limit token; without one the caller exits with
%% {internalerror, "Rate limiting"}.
-spec add_chain(binary(), [binary()], normal|precert) -> {[{_,_},...]}.
add_chain(LeafCert, CertChain, Type) ->
    EntryHash = crypto:hash(sha256, [LeafCert | CertChain]),
    {TSE, LeafHash} =
        case plop:get(EntryHash) of
            notfound ->
                add_new_chain(Type, LeafCert, CertChain, EntryHash);
            {_Index, StoredLeafHash, DBEntry} ->
                {stored_timestamped_entry(DBEntry), StoredLeafHash}
        end,
    Signature = get_sct(LeafHash, TSE),
    LogId = plop:get_logid(),
    {[{sct_version, ?PROTOCOL_VERSION},
      {id, base64:encode(LogId)},
      {timestamp, TSE#timestamped_entry.timestamp},
      {extensions, base64:encode(<<>>)},
      {signature, base64:encode(Signature)}]}.

%% Add a not-yet-logged chain, subject to the add_chain rate limit.
add_new_chain(Type, LeafCert, CertChain, EntryHash) ->
    case get_ratelimit_token(add_chain) of
        ok -> add_to_db(Type, LeafCert, CertChain, EntryHash);
        _ -> exit({internalerror, "Rate limiting"})
    end.

%% Recover the TimestampedEntry from an already-stored log entry.
stored_timestamped_entry(DBEntry) ->
    {MTLText, _ExtraData} = unpack_entry(DBEntry),
    MTL = deserialise_mtl(MTLText),
    MTLText = serialise(MTL),       % verify FIXME: remove
    MTL#mtl.entry.

%% Build the TimestampedEntry for a submitted chain.
%% CertChain is the list of certificates following LeafCert (callers
%% pass a list; the previous spec wrongly declared binary()).
%% For x509_entry the leaf certificate is stored as-is.  For
%% precert_entry, x509:detox/2 supplies the TBSCertificate and issuer key
%% hash stored in the PreCert.  Extensions are left at their default.
-spec timestamped_entry(integer(), entry_type(), binary(), [binary()]) ->
                               timestamped_entry().
timestamped_entry(Timestamp, EntryType, LeafCert, CertChain) ->
    SignedEntry =
        case EntryType of
            x509_entry ->
                #signed_x509_entry{asn1_cert = LeafCert};
            precert_entry ->
                {DetoxedLeafTBSCert, IssuerKeyHash} =
                    x509:detox(LeafCert, CertChain),
                #signed_precert_entry{
                   issuer_key_hash = IssuerKeyHash,
                   tbs_certificate = DetoxedLeafTBSCert}
        end,
    #timestamped_entry{timestamp = Timestamp,
                       entry_type = EntryType,
                       signed_entry = SignedEntry}.

%% Parse a serialised MerkleTreeLeaf: one version byte, one leaf-type
%% byte, then the TimestampedEntry.  Malformed input crashes.
-spec deserialise_mtl(binary()) -> mtl().
deserialise_mtl(Data) ->
    <<VersionByte:1/binary, TypeByte:1/binary, EntryBytes/binary>> = Data,
    Version = deserialise_leaf_version(VersionByte),
    LeafType = deserialise_leaf_type(TypeByte),
    Entry = deserialise_timestampedentry(EntryBytes),
    #mtl{leaf_version = Version, leaf_type = LeafType, entry = Entry}.

%% Parse a serialised TimestampedEntry: a 64-bit timestamp, a 2-byte
%% entry type, the type-specific signed entry, and a 16-bit-length
%% extensions vector.  Input must be fully consumed; trailing bytes
%% after the extensions vector cause a badmatch.
-spec deserialise_timestampedentry(binary()) -> timestamped_entry().
deserialise_timestampedentry(Data) ->
    <<Timestamp:64, TypeBytes:2/binary, Body/binary>> = Data,
    EntryType = deserialise_entry_type(TypeBytes),
    {SignedEntry, Trailer} =
        case EntryType of
            precert_entry -> deserialise_signed_precert_entry(Body);
            x509_entry -> deserialise_signed_x509_entry(Body)
        end,
    {Extensions, <<>>} = decode_tls_vector(Trailer, 2),
    #timestamped_entry{timestamp = Timestamp,
                       entry_type = EntryType,
                       signed_entry = SignedEntry,
                       extensions = Extensions}.

%% Read one 24-bit-length ASN.1Cert vector; return it wrapped in a
%% signed_x509_entry together with the unread remainder.
-spec deserialise_signed_x509_entry(binary()) -> {signed_x509_entry(), binary()}.
deserialise_signed_x509_entry(Data) ->
    {Cert, Remainder} = decode_tls_vector(Data, 3),
    Entry = #signed_x509_entry{asn1_cert = Cert},
    {Entry, Remainder}.

%% Read a PreCert: a 32-byte issuer key hash followed by a 24-bit-length
%% TBSCertificate vector.  Returns the record and the unread remainder.
-spec deserialise_signed_precert_entry(binary()) ->
                                              {signed_precert_entry(), binary()}.
deserialise_signed_precert_entry(Data) ->
    <<KeyHash:32/binary, AfterHash/binary>> = Data,
    {TBS, Remainder} = decode_tls_vector(AfterHash, 3),
    Entry = #signed_precert_entry{issuer_key_hash = KeyHash,
                                  tbs_certificate = TBS},
    {Entry, Remainder}.

%% Fetch log entries Start..End from plop and format each one with
%% base64-encoded leaf_input and extra_data.
-spec entries(non_neg_integer(), non_neg_integer()) -> {[{entries, list()},...]}.
entries(Start, End) ->
    Formatted = x_entries(plop:get(Start, End)),
    {[{entries, Formatted}]}.

%% Return the entry at Index together with its audit path for a tree of
%% size TreeSize, all base64-encoded.  When plop reports notfound, return
%% a `success' = false object carrying plop's message.
-spec entry_and_proof(non_neg_integer(), non_neg_integer()) -> {[{_,_},...]}.
entry_and_proof(Index, TreeSize) ->
    case plop:inclusion_and_entry(Index, TreeSize) of
        {notfound, Msg} ->
            {[{success, false},
              {error_message, list_to_binary(Msg)}]};
        {ok, Entry, Path} ->
            {LeafInput, ExtraData} = unpack_entry(Entry),
            AuditPath = [base64:encode(Node) || Node <- Path],
            {[{leaf_input, base64:encode(LeafInput)},
              {extra_data, base64:encode(ExtraData)},
              {audit_path, AuditPath}]}
    end.

-define(CACHE_TABLE, catlfish_cache).
%% (Re)create the named, public ETS set used as this module's cache.
%% Any existing table with that name is deleted first, so previously
%% cached values are dropped.  Returns the table name.
init_cache_table() ->
    TableExists = ets:info(?CACHE_TABLE) =/= undefined,
    if
        TableExists -> ets:delete(?CACHE_TABLE);
        true -> ok
    end,
    ets:new(?CACHE_TABLE, [set, public, named_table]).

%% Split a concatenation of 24-bit-length vectors into the list of their
%% payloads, in order.  Malformed input crashes.
deserialise_extra_data(ExtraData) ->
    collect_extra_data(ExtraData, []).

%% Accumulating loop for deserialise_extra_data/1.
collect_extra_data(<<>>, Acc) ->
    lists:reverse(Acc);
collect_extra_data(Data, Acc) ->
    {Cert, Rest} = decode_tls_vector(Data, 3),
    collect_extra_data(Rest, [Cert | Acc]).

%% Reconstruct the submitted chain from a stored leaf and its extra data.
%% For x509_entry the leaf certificate lives in the MTL and is prepended.
%% For precert_entry the extra data already starts with the submitted
%% leaf, so it is returned unchanged.
chain_from_mtl_extradata(MTL, ExtraData) ->
    TSE = MTL#mtl.entry,
    Chain = deserialise_extra_data(ExtraData),
    case TSE#timestamped_entry.entry_type of
        precert_entry ->
            Chain;
        x509_entry ->
            X509Entry = TSE#timestamped_entry.signed_entry,
            Leaf = X509Entry#signed_x509_entry.asn1_cert,
            [Leaf | Chain]
    end.

%% Split a stored log entry into its decoded MerkleTreeLeaf and the
%% extra-data bytes with the inner 24-bit-length wrapper removed (i.e.
%% the concatenated certificate vectors).
mtl_and_extra_from_entry(Entry) ->
    {LeafBytes, PackedExtra} = unpack_entry(Entry),
    {ExtraData, <<>>} = decode_tls_vector(PackedExtra, 3),
    {deserialise_mtl(LeafBytes), ExtraData}.

%% Check that MTL is exactly the v1 timestamped_entry leaf that would be
%% built from LeafCert/CertChain using MTL's own timestamp and entry
%% type.  The rebuilt entry uses default (empty) extensions.
%% Returns ok or error.
verify_mtl(MTL, LeafCert, CertChain) ->
    Stored = MTL#mtl.entry,
    Rebuilt = timestamped_entry(Stored#timestamped_entry.timestamp,
                                Stored#timestamped_entry.entry_type,
                                LeafCert, CertChain),
    Expected = #mtl{leaf_version = v1,
                    leaf_type = timestamped_entry,
                    entry = Rebuilt},
    if
        MTL =:= Expected -> ok;
        true -> error
    end.

%% Verify Entry against the configured known roots.
verify_entry(Entry) -> verify_entry(Entry, known_roots()).

%% Verify a stored log entry: its chain must normalise against RootCerts
%% via x509:normalise_chain/2, and its MTL must match one rebuilt from the
%% normalised chain.  Returns {ok, LeafHash} or {error, Reason}.
verify_entry(Entry, RootCerts) ->
    {MTL, ExtraData} = mtl_and_extra_from_entry(Entry),
    Chain = chain_from_mtl_extradata(MTL, ExtraData),
    case x509:normalise_chain(RootCerts, Chain) of
        {error, _Reason} = Failure ->
            Failure;
        {ok, [LeafCert | CertChain]} ->
            case verify_mtl(MTL, LeafCert, CertChain) of
                error -> {error, "MTL verification failed"};
                ok -> {ok, ht:leaf_hash(serialise(MTL))}
            end
    end.

%% SHA-256 of the submitted chain reconstructed from a stored entry. This
%% is the same key add_chain/3 uses to look entries up.
entryhash_from_entry(Entry) ->
    {MTL, ExtraData} = mtl_and_extra_from_entry(Entry),
    crypto:hash(sha256, chain_from_mtl_extradata(MTL, ExtraData)).

%% Private functions.
%% Split a stored entry into its two 32-bit-length vectors (MTL text and
%% packed extra data); the entry must contain nothing else.
-spec unpack_entry(binary()) -> {binary(), binary()}.
unpack_entry(Entry) ->
    {LeafInput, AfterLeaf} = decode_tls_vector(Entry, 4),
    {ExtraData, <<>>} = decode_tls_vector(AfterLeaf, 4),
    {LeafInput, ExtraData}.

%% Format {Index, Hash, Entry} rows from plop as
%% {[{leaf_input, B64}, {extra_data, B64}]} objects, preserving order.
-spec x_entries([{non_neg_integer(), binary(), binary()}]) -> list().
x_entries(Rows) ->
    lists:map(fun x_entry/1, Rows).

%% Format a single {Index, Hash, Entry} row.
x_entry(Row) ->
    {_Index, _Hash, Entry} = Row,
    {LeafInput, ExtraData} = unpack_entry(Entry),
    {[{leaf_input, base64:encode(LeafInput)},
      {extra_data, base64:encode(ExtraData)}]}.

%% Encode Binary as a TLS variable-length vector: a big-endian length
%% prefix of LengthLen bytes followed by the data (RFC 5246 Section 4.3).
%% Previously an over-long Binary silently had its length truncated to
%% the low LengthLen bytes, producing a corrupt vector.  It now raises
%% error({badarg, {tls_vector_too_long, Length, LengthLen}}).
-spec encode_tls_vector(binary(), non_neg_integer()) -> binary().
encode_tls_vector(Binary, LengthLen) ->
    Length = byte_size(Binary),
    case Length < 1 bsl (8 * LengthLen) of
        true ->
            <<Length:LengthLen/integer-unit:8, Binary/binary>>;
        false ->
            erlang:error({badarg, {tls_vector_too_long, Length, LengthLen}})
    end.

%% Decode one TLS variable-length vector (big-endian length prefix of
%% LengthLen bytes) from the front of Binary.  Returns {Payload, Rest}.
%% Truncated input raises badmatch.
-spec decode_tls_vector(binary(), non_neg_integer()) -> {binary(), binary()}.
decode_tls_vector(Binary, LengthLen) ->
    <<Len:LengthLen/unit:8, Payload/binary>> = Binary,
    <<Vector:Len/binary, Tail/binary>> = Payload,
    {Vector, Tail}.

-define(ROOTS_CACHE_KEY, roots).

%% Re-read the root certificates from the catlfish `known_roots_path'
%% directory and refresh the cache.  Returns [] when that env var is
%% not set.
update_known_roots() ->
    case application:get_env(catlfish, known_roots_path) of
        undefined -> [];
        {ok, RootsDir} -> update_known_roots(RootsDir)
    end.

%% Force a re-read of the roots in RootsDir, bypassing the cache.
update_known_roots(RootsDir) -> known_roots(RootsDir, update_tab).

%% Known root certificates from the catlfish `known_roots_path'
%% directory, served from the cache when present.  Returns [] when that
%% env var is not set.
known_roots() ->
    case application:get_env(catlfish, known_roots_path) of
        undefined -> [];
        {ok, RootsDir} -> known_roots(RootsDir, use_cache)
    end.

%% Return the known root certificates for Directory.
%% use_cache: serve the cached list if present, otherwise read the
%%            directory and populate the cache.
%% update_tab: always re-read the directory and refresh the cache.
%% The lookup matches on ?ROOTS_CACHE_KEY, the same key used on insert,
%% instead of the previously hard-coded `roots' atom.
-spec known_roots(file:filename(), use_cache|update_tab) -> [binary()].
known_roots(Directory, use_cache) ->
    case ets:lookup(?CACHE_TABLE, ?ROOTS_CACHE_KEY) of
        [] ->
            read_files_and_update_table(Directory);
        [{?ROOTS_CACHE_KEY, DerList}] ->
            DerList
    end;
known_roots(Directory, update_tab) ->
    read_files_and_update_table(Directory).

%% Read PEM certificates from Directory and keep those accepted by
%% x509:self_signed/1.  The result is stored in the cache table under
%% ?ROOTS_CACHE_KEY and returned.  A warning is logged if any were dropped.
read_files_and_update_table(Directory) ->
    AllCerts = x509:read_pemfiles_from_dir(Directory),
    Roots = x509:self_signed(AllCerts),
    Ignored = length(AllCerts) - length(Roots),
    if
        Ignored =/= 0 ->
            lager:warning(
              "Ignoring ~p root certificates not signing themselves properly",
              [Ignored]);
        true ->
            ok
    end,
    true = ets:insert(?CACHE_TABLE, {?ROOTS_CACHE_KEY, Roots}),
    lager:info("Known roots imported: ~p", [length(Roots)]),
    Roots.

%%%%%%%%%%%%%%%%%%%%
%% Testing internal functions.
-define(PEMFILES_DIR_OK, "test/testdata/known_roots").
-define(PEMFILES_DIR_NONEXISTENT, "test/testdata/nonexistent-dir").

%% Reading the test roots directory yields 4 roots, and a subsequent
%% cached read returns the same list.
read_pemfiles_test_() ->
    Setup = fun() ->
                    init_cache_table(),
                    Fresh = known_roots(?PEMFILES_DIR_OK, update_tab),
                    Cached = known_roots(?PEMFILES_DIR_OK, use_cache),
                    {Fresh, Cached}
            end,
    Cleanup = fun(_) -> ets:delete(?CACHE_TABLE, ?ROOTS_CACHE_KEY) end,
    Tests = fun({Fresh, Cached}) ->
                    [?_assertMatch(4, length(Fresh)),
                     ?_assertEqual(Fresh, Cached)]
            end,
    {setup, Setup, Cleanup, Tests}.

%% A nonexistent roots directory yields an empty root list.
read_pemfiles_fail_test_() ->
    Setup = fun() ->
                    init_cache_table(),
                    known_roots(?PEMFILES_DIR_NONEXISTENT, update_tab)
            end,
    Cleanup = fun(_) -> ets:delete(?CACHE_TABLE, ?ROOTS_CACHE_KEY) end,
    {setup, Setup, Cleanup,
     fun(Roots) -> [?_assertMatch([], Roots)] end}.