%%% Copyright (c) 2014, NORDUnet A/S.
%%% See LICENSE for licensing information.

-module(x509).
-export([normalise_chain/2, cert_string/1, read_pemfiles_from_dir/1]).

-include_lib("public_key/include/public_key.hrl").
-include_lib("eunit/include/eunit.hrl").

%% Reasons reported when a chain is rejected. These are bare atoms:
%% the functions in this module return them as the second element of
%% {false, Reason} / {error, Reason}, so the type must not itself be
%% wrapped in a tuple.
-type reason() :: chain_too_long |
                  root_unknown |
                  signature_mismatch |
                  encoding_invalid.

%% Chain-length budget that normalise_chain/2 passes to valid_chain_p/3.
-define(MAX_CHAIN_LENGTH, 10).

%% @doc Validate CertChain (leaf first) against AcceptableRootCerts
%% using valid_chain_p/3 with ?MAX_CHAIN_LENGTH as budget. On success
%% the chain is returned with the leaf passed through detox_precert/1
%% and with whatever list valid_chain_p/3 reported appended; on
%% failure the reason is returned in an error tuple.
-spec normalise_chain([binary()], [binary()]) -> {ok, [binary()]} |
                                                 {error, reason()}.
normalise_chain(AcceptableRootCerts, CertChain) ->
    Verdict = valid_chain_p(AcceptableRootCerts, CertChain,
                            ?MAX_CHAIN_LENGTH),
    case Verdict of
        {true, RootToAppend} ->
            [Leaf | Intermediates] = CertChain,
            {ok, [detox_precert(Leaf) | Intermediates] ++ RootToAppend};
        {false, Why} ->
            {error, Why}
    end.

%%%%%%%%%%%%%%%%%%%%
%% @doc Check that the chain in the second argument (leaf first, each
%% following element expected to have signed the one before it) leads
%% back to one of the acceptable root certs in the first argument. The
%% order of the root certs does not matter. The third argument is the
%% remaining length budget; it is decremented for every link walked.
%%
%% Returns {true, []} when the last cert of the chain is itself an
%% acceptable root, {true, [Root]} when it is not but is signed by
%% Root from the acceptable roots, and {false, Reason} otherwise.
-spec valid_chain_p([binary()], [binary()], integer()) ->
                           {false, reason()} | {true, list()}.
valid_chain_p(_Roots, _Chain, Budget) when Budget =< 0 ->
    %% Budget exhausted before reaching the end of the chain.
    {false, chain_too_long};
valid_chain_p(Roots, [Top], Budget) ->
    %% End of chain: decide based on whether Top is a known root.
    check_chain_top(lists:member(Top, Roots), Top, Roots, Budget);
valid_chain_p(Roots, [Cert | ([Issuer | _] = Upwards)], Budget) ->
    %% Walk one link up the chain as long as the signature checks out.
    case signed_by_p(Cert, Issuer) of
        true -> valid_chain_p(Roots, Upwards, Budget - 1);
        Failure -> Failure
    end.

%% Decide the outcome for the last cert of a chain. The first argument
%% tells whether that cert is a member of the acceptable roots.
check_chain_top(true, _Top, _Roots, _Budget) ->
    %% The chain already ends in an acceptable root.
    {true, []};
check_chain_top(false, _Top, _Roots, Budget) when Budget =< 1 ->
    %% A root outside the chain would need one more link than allowed.
    {false, chain_too_long};
check_chain_top(false, Top, Roots, _Budget) ->
    %% Top _might_ be signed by one of the acceptable roots.
    case signer(Top, Roots) of
        notfound -> {false, root_unknown};
        Root -> {true, [Root]}
    end.

%% @doc Return the first element of the candidate list for which
%% signed_by_p/2 reports that it signed Cert, or notfound when there
%% is none. NOTE: this is potentially expensive, since a signature
%% check is attempted per candidate. Searching for Cert.issuer among
%% the candidates' subjects would be more efficient; if that is done,
%% the matching might have to be somewhat fuzzy, unless that too
%% turns out to be expensive.
-spec signer(binary(), [binary()]) -> notfound | binary().
signer(Cert, Candidates) ->
    NotSigner = fun(Candidate) -> signed_by_p(Cert, Candidate) =/= true end,
    case lists:dropwhile(NotSigner, Candidates) of
        [] -> notfound;
        [Signer | _] -> Signer
    end.

%% Adapted from pubkey_cert:encoded_tbs_cert/1.
%% Run the exclusive TBSCert decode on DerCert and return the term
%% found inside the 'Certificate_tbsCertificate' wrapper. Crashes with
%% badmatch when the decode does not yield the expected shape.
encoded_tbs_cert(DerCert) ->
    {ok, Decoded} = 'OTP-PUB-KEY':decode_TBSCert_exclusive(DerCert),
    {'Certificate', {'Certificate_tbsCertificate', TbsBytes}, _, _} = Decoded,
    TbsBytes.

%% Adapted from pubkey_cert:extract_verify_data/2.
%% Collect what is needed to check the signature on a certificate:
%% the result of encoded_tbs_cert/1 on DerCert, the digest type that
%% public_key:pkix_sign_types/1 gives for the certificate's signature
%% algorithm, and the second element of the certificate's signature
%% field. Cert is the decoded #'Certificate'{} for DerCert.
verifydata_from_cert(Cert, DerCert) ->
    TbsBytes = encoded_tbs_cert(DerCert),
    {_, Signature} = Cert#'Certificate'.signature,
    AlgId = Cert#'Certificate'.signatureAlgorithm,
    {HashAlg, _} =
        public_key:pkix_sign_types(AlgId#'AlgorithmIdentifier'.algorithm),
    {TbsBytes, HashAlg, Signature}.

%% @doc Verify that Cert/DerCert is signed by Issuer.
%% Cert is the decoded #'Certificate'{} and DerCert the corresponding
%% DER binary; the third argument is the decoded issuer certificate.
%% Returns whatever public_key:verify/4 returns (signed_by_p/2 expects
%% true or false). Crashes if the issuer key is not in one of the
%% shapes matched below.
verify(Cert, DerCert,                         % Certificate to verify.
       #'Certificate'{                        % Issuer.
          tbsCertificate = #'TBSCertificate'{
                              subjectPublicKeyInfo = IssuerSPKI}}) ->

    %% Dig out digest, digest type and signature from Cert/DerCert.
    {DigestOrPlainText, DigestType, Signature} = verifydata_from_cert(Cert,
                                                                      DerCert),
    %% Dig out issuer key from issuer cert.
    %% NOTE(review): the {0, Key0} pattern only matches a two-tuple
    %% whose first element is 0 -- presumably the {UnusedBits, Bytes}
    %% bit string representation; confirm against the public_key
    %% version in use, anything else crashes here with badmatch.
    #'SubjectPublicKeyInfo'{
       algorithm = #'AlgorithmIdentifier'{algorithm = Alg, parameters = Params},
       subjectPublicKey = {0, Key0}} = IssuerSPKI,
    KeyType = pubkey_cert_records:supportedPublicKeyAlgorithms(Alg),
    %% Only the three key types below are handled; any other KeyType
    %% raises case_clause.
    IssuerKey =
        case KeyType of
            'RSAPublicKey' ->
                public_key:der_decode(KeyType, Key0);
            'DSAPublicKey' ->
                %% DSA keys are paired with the parameters decoded
                %% from the algorithm identifier.
                {params, DssParams} = public_key:der_decode('DSAParams', Params),
                {public_key:der_decode(KeyType, Key0), DssParams};
            'ECPoint' ->
                public_key:der_decode(KeyType, Key0)
        end,

    %% Verify the signature.
    public_key:verify(DigestOrPlainText, DigestType, Signature, IssuerKey).

%% @doc Tell whether the DER-encoded certificate DerCert carries a
%% signature made by the key in IssuerDerCert. Only the signature is
%% checked; nothing else (such as Cert.issuer == Issuer.subject) is
%% compared. Returns true or {false, signature_mismatch}.
-spec signed_by_p(binary(), binary()) -> true | {false, reason()}.
signed_by_p(DerCert, IssuerDerCert) when is_binary(DerCert),
                                         is_binary(IssuerDerCert) ->
    Subject = public_key:pkix_decode_cert(DerCert, plain),
    Issuer = public_key:pkix_decode_cert(IssuerDerCert, plain),
    case verify(Subject, DerCert, Issuer) of
        true -> true;
        false -> {false, signature_mismatch}
    end.

%% @doc Return the subjectPublicKey field of a certificate. A DER
%% binary is first decoded with public_key:pkix_decode_cert/2 in otp
%% mode; an already decoded #'OTPCertificate'{} is used as is.
-spec public_key(binary() | #'OTPCertificate'{}) -> public_key:public_key().
public_key(CertDer) when is_binary(CertDer) ->
    Decoded = public_key:pkix_decode_cert(CertDer, otp),
    public_key(Decoded);
public_key(#'OTPCertificate'{
              tbsCertificate = #'OTPTBSCertificate'{
                  subjectPublicKeyInfo = #'OTPSubjectPublicKeyInfo'{
                      subjectPublicKey = SubjectKey}}}) ->
    SubjectKey.

%% @doc Identify a certificate by the SHA-1 hash of Der, passed
%% through mochihex:to_hex/1. Used in log messages and file names in
%% this module.
cert_string(Der) ->
    Digest = crypto:hash(sha, Der),
    mochihex:to_hex(Digest).

%% @doc Tell whether Der decodes to a #'Certificate'{} with
%% public_key:pkix_decode_cert/2 in plain mode. A failed decode is
%% logged at info level and reported as false.
parsable_cert_p(Der) ->
    Outcome = (catch public_key:pkix_decode_cert(Der, plain)),
    case Outcome of
        {'EXIT', Why} ->
            %% The decoder crashed.
            lager:info("invalid certificate: ~p: ~p", [cert_string(Der), Why]),
            false;
        #'Certificate'{} ->
            true;
        Other ->
            %% Neither a certificate record nor a crash.
            lager:info("unknown error decoding cert: ~p: ~p",
                       [cert_string(Der), Other]),
            false
    end.

%%%%%%%%%%%%%%%%%%%%
%% Precertificates according to draft-ietf-trans-rfc6962-bis-04.

%% Submitted precerts have a special critical poison extension -- OID
%% 1.3.6.1.4.1.11129.2.4.3, whose extnValue OCTET STRING contains
%% ASN.1 NULL data (0x05 0x00).

%% They are signed with either the CA cert that will sign the final
%% cert or a Precertificate Signing Certificate directly signed by
%% the CA cert that will sign the final cert. A Precertificate Signing
%% Certificate has CA:true and Extended Key Usage: Certificate
%% Transparency, OID 1.3.6.1.4.1.11129.2.4.4.

%% A PreCert in a SignedCertificateTimestamp does _not_ contain the
%% poison extension, nor a Precertificate Signing Certificate. This
%% means that we might have to 1) remove poison extensions in leaf
%% certs, 2) remove "poisoned signatures", 3) change issuer and
%% Authority Key Identifier of leaf certs.

%% @doc Placeholder for precertificate "detoxing" of a leaf cert. Not
%% yet implemented: the argument is returned unchanged.
%%
%% The only caller, normalise_chain/2, passes the DER-encoded leaf (a
%% single binary) and puts the result back at the head of a list of
%% DER binaries, so the spec is binary() -> binary() rather than a
%% list of decoded certificates.
-spec detox_precert(binary()) -> binary().
detox_precert(LeafDer) ->
    LeafDer.                                    % NYI

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
-spec read_pemfiles_from_dir(file:filename()) -> [binary()].
%% @doc Read certificates from every file in Dir whose extension is
%% ".pem" and return them as one flat list -- certs from all files
%% end up in the same list. If the directory cannot be listed, the
%% problem is logged and the empty list is returned.
read_pemfiles_from_dir(Dir) ->
    case file:list_dir(Dir) of
        {ok, Entries} ->
            PemFiles = [Entry || Entry <- Entries,
                                 string:equal(".pem",
                                              filename:extension(Entry))],
            ders_from_pemfiles(Dir, PemFiles);
        {error, enoent} ->
            lager:error("directory does not exist: ~p", [Dir]),
            [];
        {error, Reason} ->
            lager:error("unable to read directory ~p: ~p", [Dir, Reason]),
            []
    end.

%% Run ders_from_pemfile/1 on each of Filenames (taken relative to
%% Dir) and flatten the per-file results into a single list.
ders_from_pemfiles(Dir, Filenames) ->
    PerFile = lists:map(
                fun(Name) -> ders_from_pemfile(filename:join(Dir, Name)) end,
                Filenames),
    lists:flatten(PerFile).

%% Read Filename, PEM-decode its contents and map every decoded entry
%% through der_from_pem/1. If the PEM decoding crashes, this is logged
%% and the file contributes no entries.
ders_from_pemfile(Filename) ->
    FileData = pems_from_file(Filename),
    Entries =
        case (catch public_key:pem_decode(FileData)) of
            {'EXIT', Why} ->
                lager:info("~p: invalid PEM-encoding: ~p", [Filename, Why]),
                [];
            Decoded ->
                Decoded
        end,
    lists:map(fun der_from_pem/1, Entries).

%% Turn one decoded PEM entry into its DER binary. An unencrypted
%% entry whose DER parses as a certificate yields that DER; one that
%% does not parse is handed to dump_unparsable_cert/1 and yields [].
%% Any other entry is logged and yields [] as well.
der_from_pem({_Type, Der, not_encrypted}) ->
    case parsable_cert_p(Der) of
        false ->
            dump_unparsable_cert(Der),
            [];
        true ->
            Der
    end;
der_from_pem(Other) ->
    lager:info("ignoring PEM-encoded data: ~p~n", [Other]),
    [].

%% Return the contents of Filename as a binary. Crashes with badmatch
%% if the file cannot be read.
-spec pems_from_file(file:filename()) -> binary().
pems_from_file(Filename) ->
    {ok, FileContents} = file:read_file(Filename),
    FileContents.

%% @doc Save a certificate that could not be parsed, for later
%% inspection. If the application environment variable
%% rejected_certs_path is set for catlfish, CertDer is written to a
%% file in that directory, named from cert_string/1 of the cert plus
%% the current time, and the result of file:write_file/2 is returned.
%% Otherwise nothing is written and not_logged is returned.
%%
%% NOTE(review): the name is formatted with ~p, which puts quotes
%% around string arguments; if cert_string/1 returns a string the file
%% name will contain double quotes -- confirm this is intended.
-spec dump_unparsable_cert(binary()) -> ok | {error, atom()} | not_logged.
dump_unparsable_cert(CertDer) ->
    case application:get_env(catlfish, rejected_certs_path) of
        {ok, Directory} ->
            %% os:timestamp/0 returns the same {MegaSecs, Secs,
            %% MicroSecs} triple as the deprecated now/0.
            {NowMegaSec, NowSec, NowMicroSec} = os:timestamp(),
            Filename =
                filename:join(Directory,
                              io_lib:format("~p:~p.~p",
                                            [cert_string(CertDer),
                                             NowMegaSec * 1000 * 1000 + NowSec,
                                             NowMicroSec])),
            lager:debug("dumping cert to ~p~n", [Filename]),
            file:write_file(Filename, CertDer);
        _ ->
            not_logged
    end.

%%%%%%%%%%%%%%%%%%%%
%% Testing private functions.
-include("x509_test.hrl").
%% Check that signed_by_p/2 accepts the ?C0/?C1 pair from
%% x509_test.hrl.
sign_test_() ->
    Setup = fun() -> ok end,
    Cleanup = fun(_) -> ok end,
    Instantiate =
        fun(_) -> [?_assertMatch(true, signed_by_p(?C0, ?C1))] end,
    {setup, Setup, Cleanup, Instantiate}.

%% Run valid_chain_p/3 over the chains found in test/testdata/chains,
%% using the certs in test/testdata/known_roots as the trust store.
valid_cert_test_() ->
    LoadFixtures =
        fun() ->
                {read_pemfiles_from_dir("test/testdata/known_roots"),
                 read_certs("test/testdata/chains")}
        end,
    NoCleanup = fun(_) -> ok end,
    Instantiate =
        fun({KnownRoots, Chains}) ->
                %% Looked up when each test runs, not when the list of
                %% tests is built.
                Chain = fun(N) -> lists:nth(N, Chains) end,
                [
                 %% Chain 1: self-signed, and not a valid OTPCertificate:
                 %% {error,{asn1,{invalid_choice_tag,{22,<<"US">>}}}}
                 %% 'OTP-PUB-KEY':Func('OTP-X520countryname', Value0)
                 %% FIXME: this doesn't make much sense -- is my
                 %% environment borked?
                 ?_assertMatch({true, _},
                               valid_chain_p(Chain(1), Chain(1), 10)),
                 %% Chain 2: self-signed, not among the known roots.
                 ?_assertMatch({false, root_unknown},
                               valid_chain_p(KnownRoots, Chain(2), 10)),
                 %% Chain 3: leaf signed by a known CA.
                 ?_assertMatch({true, _},
                               valid_chain_p(KnownRoots, Chain(3), 10)),
                 %% Chain 4: bug CATLFISH-19 --> [info] rejecting
                 %% "3ee62cb678014c14d22ebf96f44cc899adea72f1": chain_broken
                 %% leaf sha1: 3ee62cb678014c14d22ebf96f44cc899adea72f1
                 %% leaf Subject: C=KR, O=Government of Korea, OU=Group of Server, OU=\xEA\xB5\x90\xEC\x9C\xA1\xEA\xB3\xBC\xED\x95\x99\xEA\xB8\xB0\xEC\x88\xA0\xEB\xB6\x80, CN=www.berea.ac.kr, CN=haksa.bits.ac.kr
                 ?_assertMatch({true, _},
                               valid_chain_p(Chain(4), Chain(4), 10))
                ]
        end,
    {setup, LoadFixtures, NoCleanup, Instantiate}.

%% Fixture handing the ?C0/?C1 pair from x509_test.hrl to
%% chain_test/2.
chain_test_() ->
    Setup = fun() -> {?C0, ?C1} end,
    Cleanup = fun(_) -> ok end,
    Instantiate = fun({C0, C1}) -> chain_test(C0, C1) end,
    {setup, Setup, Cleanup, Instantiate}.

%% Build the valid_chain_p/3 test cases for a cert and the cert that
%% signed it.
chain_test(Cert, Issuer) ->
    [
     %% Issuer missing from the chain but present in the trust store.
     ?_assertEqual({true, [Issuer]}, valid_chain_p([Issuer], [Cert], 10)),
     ?_assertEqual({true, [Issuer]}, valid_chain_p([Issuer], [Cert], 2)),
     %% Same, but the budget leaves no room for the issuer.
     ?_assertMatch({false, chain_too_long},
                   valid_chain_p([Issuer], [Cert], 1)),
     %% Issuer both in the chain and in the trust store.
     ?_assertEqual({true, []}, valid_chain_p([Issuer], [Cert, Issuer], 2)),
     %% Same, but the budget is too small for two certs.
     ?_assertMatch({false, chain_too_long},
                   valid_chain_p([Issuer], [Cert, Issuer], 1)),
     %% Empty trust store.
     ?_assertMatch({false, root_unknown},
                   valid_chain_p([], [Cert, Issuer], 10)),
     %% Cert is itself in the trust store. Actually OK.
     ?_assertMatch({true, []}, valid_chain_p([Cert], [Cert], 10)),
     ?_assertMatch({true, []}, valid_chain_p([Cert], [Cert], 1)),
     %% A budget of 0 is never OK.
     ?_assertMatch({false, chain_too_long}, valid_chain_p([Cert], [Cert], 0))
    ].

%% Test helper: read every ".pem" file in Dir, in sorted file name
%% order, and return one list per successfully read file holding the
%% DER binaries of its unencrypted 'Certificate' PEM entries. Crashes
%% with badmatch if Dir cannot be listed; files that cannot be read
%% are skipped.
-spec read_certs(file:filename()) -> [[binary()]].
read_certs(Dir) ->
    {ok, Entries} = file:list_dir(Dir),
    PemNames = lists:sort(
                 [Name || Name <- Entries,
                          string:equal(".pem", filename:extension(Name))]),
    ReadResults = [file:read_file(filename:join(Dir, Name)) ||
                      Name <- PemNames],
    [[Der || {'Certificate', Der, not_encrypted} <-
                 public_key:pem_decode(Contents)] ||
        {ok, Contents} <- ReadResults].