{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -Wno-orphans #-}

module Ouroboros.Consensus.Protocol.Praos.AgentClient
  ( AgentCrypto (..)
  , KESAgentClientTrace (..)
  , KESAgentContext
  , MonadKESAgent (..)
  , runKESAgentClient
  ) where

import Cardano.Crypto.DirectSerialise
  ( DirectDeserialise
  , DirectSerialise
  )
import Cardano.Crypto.KES.Class
import Cardano.Crypto.VRF.Class
import qualified Cardano.KESAgent.KES.Bundle as Agent
import qualified Cardano.KESAgent.KES.Crypto as Agent
import qualified Cardano.KESAgent.KES.OCert as Agent
import qualified Cardano.KESAgent.Processes.ServiceClient as Agent
import qualified Cardano.KESAgent.Protocols.RecvResult as Agent
import qualified Cardano.KESAgent.Protocols.StandardCrypto as Agent
import qualified Cardano.KESAgent.Protocols.VersionedProtocol as Agent
import Cardano.KESAgent.Util.RefCounting
import Cardano.Ledger.Keys (DSIGN)
import Cardano.Protocol.Crypto (Crypto, KES, StandardCrypto, VRF)
import qualified Cardano.Protocol.TPraos.OCert as OCert
import Control.Exception (SomeAsyncException)
import Control.Monad (forever)
import Control.Monad.Class.MonadAsync
import Control.Monad.IOSim
import Control.Tracer
import Data.Coerce (coerce)
import Data.Kind
import Data.Typeable
import Network.Socket
import Ouroboros.Consensus.Util.IOLike
import Ouroboros.Network.RawBearer
import Ouroboros.Network.Snocket
import qualified Simulation.Network.Snocket as SimSnocket
import System.IOManager
import Test.Ouroboros.Network.Data.AbsBearerInfo as ABI

-- | Everything required to run a KES agent client for ledger crypto @c@ in
-- monad @m@: @m@ must be 'IOLike' and provide agent networking
-- ('MonadKESAgent'), and @c@ must have a matching agent crypto
-- ('AgentCrypto').
type KESAgentContext c m =
  ( IOLike m
  , MonadKESAgent m
  , AgentCrypto c
  )

-- | Trace events emitted by 'runKESAgentClient'.
data KESAgentClientTrace
  = -- | An exception raised while running the KES agent service client,
    -- caught and reported by 'runKESAgentClient'.
    KESAgentClientException SomeException
  | -- | A trace event forwarded from the underlying KES agent service client.
    KESAgentClientTrace Agent.ServiceClientTrace
  deriving (Show)

-- | Ledger crypto @c@ for which the KES agent has a corresponding crypto,
-- @'ACrypto' c@.
--
-- The superclass constraints require:
--
-- * unit KES and VRF contexts for @c@;
-- * direct (de)serialisation support for @c@'s KES sign keys;
-- * the agent-side class instances for @'ACrypto' c@;
-- * that the agent crypto uses exactly the same KES algorithm as @c@, so a
--   sign key received from the agent is usable as a @'SignKeyKES' ('KES' c)@.
class
  ( Crypto c
  , ContextKES (KES c) ~ ()
  , ContextVRF (VRF c) ~ ()
  , DirectSerialise (SignKeyKES (KES c))
  , DirectDeserialise (SignKeyKES (KES c))
  , Agent.Crypto (ACrypto c)
  , Agent.NamedCrypto (ACrypto c)
  , Agent.ServiceClientDrivers (ACrypto c)
  , Typeable (ACrypto c)
  , Agent.KES (ACrypto c) ~ KES c
  ) =>
  AgentCrypto c
  where
  -- | The KES agent's crypto type corresponding to the ledger crypto @c@.
  type ACrypto c :: Type

-- | Standard ledger crypto pairs with the agent's own standard crypto.
instance AgentCrypto StandardCrypto where
  type ACrypto StandardCrypto =
    Agent.StandardCrypto

-- | Convert an operational certificate received from the KES agent into the
-- ledger's 'OCert.OCert' for the corresponding crypto.
--
-- The hot verification key and the counter are copied as is, the KES period
-- is re-wrapped in the ledger's 'OCert.KESPeriod', and the signature is
-- converted with 'coerce'.
convertOCert ::
  (AgentCrypto c, Agent.DSIGN (ACrypto c) ~ DSIGN) =>
  Agent.OCert (ACrypto c) ->
  OCert.OCert c
convertOCert agentOCert =
  OCert.OCert
    { OCert.ocertVkHot = Agent.ocertVkHot agentOCert
    , OCert.ocertN = Agent.ocertN agentOCert
    , OCert.ocertKESPeriod = OCert.KESPeriod startPeriod
    , OCert.ocertSigma = coerce (Agent.ocertSigma agentOCert)
    }
  where
    startPeriod = Agent.unKESPeriod (Agent.ocertKESPeriod agentOCert)

-- | Re-wrap a KES period from the agent's representation into the ledger's.
convertPeriod :: Agent.KESPeriod -> OCert.KESPeriod
convertPeriod agentPeriod =
  case agentPeriod of
    Agent.KESPeriod n -> OCert.KESPeriod n

-- | Monads in which a KES agent client can run. An instance supplies the
-- snocket used to reach the agent, a raw-bearer constructor for that
-- snocket's file descriptors, and a way to build an address from a path.
class (MonadFail m, Show (Addr m)) => MonadKESAgent m where
  -- | File-descriptor type of the snocket.
  type FD m :: Type

  -- | Address type of the snocket.
  type Addr m :: Type

  -- | Run an action that is given a 'Snocket' to connect with.
  withAgentContext :: (Snocket m (FD m) (Addr m) -> m a) -> m a

  -- | Construct raw bearers over the snocket's file descriptors.
  makeRawBearer :: MakeRawBearer m (FD m)

  -- | Build an address from a file path; the 'Proxy' only selects the monad.
  makeAddress :: Proxy m -> FilePath -> Addr m

-- | Real networking: snockets over OS sockets, with paths mapped to
-- Unix-domain socket addresses.
instance MonadKESAgent IO where
  type FD IO = Socket
  type Addr IO = SockAddr

  -- An I/O manager is set up around the action and handed to the snocket.
  withAgentContext inner = withIOManager (inner . socketSnocket)

  makeRawBearer = makeSocketRawBearer

  makeAddress _proxy = SockAddrUnix

-- | Simulated networking: snockets from "Simulation.Network.Snocket" with
-- 'TestAddress'-wrapped file paths as addresses.
instance MonadKESAgent (IOSim s) where
  type FD (IOSim s) = SimSnocket.FD (IOSim s) (TestAddress FilePath)
  type Addr (IOSim s) = TestAddress FilePath

  -- Snocket traces are discarded, the bearer-info script map is empty, and
  -- the default bearer has no attenuation but a small connection delay.
  -- The observable network state handed to the callback is ignored.
  withAgentContext inner =
    SimSnocket.withSnocket nullTracer bearerInfo mempty $ \snocket _observe ->
      inner snocket
    where
      bearerInfo =
        toBearerInfo (absNoAttenuation{abiConnectionDelay = SmallDelay})

  makeRawBearer = SimSnocket.makeFDRawBearer nullTracer

  makeAddress _proxy = TestAddress

-- | Address scheme for simulated snockets over file-path addresses: every
-- address is reported as IPv4, and the @n@-th ephemeral address is
-- @simSnocket_<n>@.
instance SimSnocket.GlobalAddressScheme FilePath where
  getAddressType _addr = SimSnocket.IPv4Address
  ephemeralAddress _ty n = TestAddress ("simSnocket_" ++ show n)

-- | Run a client of the KES agent listening at @path@, forwarding every key
-- update the agent sends to the given callbacks.
--
-- * When the agent sends a bundle carrying a key, the client acquires one
--   extra reference to the key (taking ownership of it) and calls
--   @handleKey@ with the converted operational certificate, the KES sign key,
--   the 'Word' period stored alongside the key, and the certificate's KES
--   period.
-- * When the agent sends a bundle without a key, @handleDropKey@ is run.
--
-- The client loops forever. Whenever the service client returns, or fails
-- with a synchronous exception, the client waits 10 seconds and reconnects.
-- A synchronous exception is first traced as 'KESAgentClientException'.
-- Asynchronous exceptions (e.g. the 'AsyncCancelled' delivered by 'cancel')
-- are re-thrown, so the thread running this client can be stopped.
runKESAgentClient ::
  forall m c.
  ( KESAgentContext c m
  , Agent.DSIGN (ACrypto c) ~ DSIGN
  ) =>
  Tracer m KESAgentClientTrace ->
  FilePath ->
  (OCert.OCert c -> SignKeyKES (KES c) -> Word -> OCert.KESPeriod -> m ()) ->
  m () ->
  m ()
runKESAgentClient tracer path handleKey handleDropKey =
  withAgentContext $ \snocket ->
    forever $ do
      Agent.runServiceClient
        (Proxy @(ACrypto c))
        makeRawBearer
        ( Agent.ServiceClientOptions
            { Agent.serviceClientSnocket = snocket
            , Agent.serviceClientAddress = makeAddress (Proxy @m) path
            } ::
            Agent.ServiceClientOptions m (FD m) (Addr m)
        )
        ( \(Agent.TaggedBundle mBundle _) ->
            case mBundle of
              Just (Agent.Bundle skpRef ocert) -> do
                -- We take ownership of the key, so we acquire one extra reference,
                -- preventing the key from being discarded after `handleKey`
                -- finishes.
                _ <- acquireCRef skpRef
                withCRefValue skpRef $ \(SignKeyWithPeriodKES sk p) ->
                  handleKey
                    (convertOCert ocert)
                    sk
                    p
                    (convertPeriod $ Agent.ocertKESPeriod ocert)
                return Agent.RecvOK
              Nothing -> do
                handleDropKey
                return Agent.RecvOK
        )
        (contramap KESAgentClientTrace tracer)
        `catch` ( \(e :: SomeException) ->
                    case fromException e of
                      -- Never swallow asynchronous exceptions: catching
                      -- 'AsyncCancelled' (or 'ThreadKilled') here would keep
                      -- this 'forever' loop running after 'cancel', and
                      -- 'cancel' would then block waiting for it.
                      Just (_ :: SomeAsyncException) -> throwIO e
                      Nothing -> traceWith tracer (KESAgentClientException e)
                )
      -- The 'threadDelay' re-exported by IOLike takes a 'DiffTime' measured in
      -- seconds (not microseconds), so this waits 10 seconds before
      -- reconnecting.
      threadDelay 10

-- | Interpret an abstract bearer description as a concrete simulated
-- 'SimSnocket.BearerInfo', translating each abstract field with the
-- corresponding "Test.Ouroboros.Network.Data.AbsBearerInfo" helper.
toBearerInfo :: ABI.AbsBearerInfo -> SimSnocket.BearerInfo
toBearerInfo abi =
  SimSnocket.BearerInfo
    { SimSnocket.biConnectionDelay = ABI.delay (ABI.abiConnectionDelay abi)
    , SimSnocket.biInboundAttenuation = attenuation (ABI.abiInboundAttenuation abi)
    , SimSnocket.biOutboundAttenuation = attenuation (ABI.abiOutboundAttenuation abi)
    , SimSnocket.biInboundWriteFailure = ABI.abiInboundWriteFailure abi
    , SimSnocket.biOutboundWriteFailure = ABI.abiOutboundWriteFailure abi
    , SimSnocket.biAcceptFailures = fmap concreteAcceptFailure (abiAcceptFailure abi)
    , SimSnocket.biSDUSize = toSduSize (ABI.abiSDUSize abi)
    }
  where
    -- Only the delay needs translating; the error is passed through unchanged.
    concreteAcceptFailure (errDelay, errType) = (ABI.delay errDelay, errType)