<?php

namespace RAP;

use phpseclib\Crypt\RSA;

/**
 * Manages the JWT Key Sets (currently only RSA).
 *
 * Publishes the locally generated RSA signing keys as a JWKS document and
 * imports the RSA public keys exposed by the remote JWKS endpoints listed in
 * the configuration (`jwksUrls`).
 */
class JWKSHandler {

    private $locator;

    public function __construct(Locator $locator) {
        $this->locator = $locator;
    }

    /**
     * Generates a new 2048-bit RSA key pair for RS256 signatures, persists it
     * through the JWKS DAO and returns it.
     *
     * The private key is exported as PKCS#1 and the public key as PKCS#8; the
     * key id is 8 random bytes rendered as 16 hex characters.
     *
     * @return RSAKeyPair the newly stored key pair
     */
    public function generateKeyPair() {

        $rsa = new RSA();

        $rsa->setPrivateKeyFormat(RSA::PRIVATE_FORMAT_PKCS1);
        $rsa->setPublicKeyFormat(RSA::PUBLIC_FORMAT_PKCS8);
        // Guacamole needs a key of at least 2048
        $result = $rsa->createKey(2048);

        $keyPair = new RSAKeyPair();
        $keyPair->alg = 'RS256';
        $keyPair->privateKey = $result['privatekey'];
        $keyPair->publicKey = $result['publickey'];
        $keyPair->keyId = bin2hex(random_bytes(8));

        $dao = $this->locator->getJWKSDAO();
        $dao->insertRSAKeyPair($keyPair);

        return $keyPair;
    }

    /**
     * Builds the public JWKS document for the key pairs returned by the DAO.
     *
     * The RSA modulus ("n") and exponent ("e") are emitted as base64url
     * WITHOUT trailing padding, as required for JWK members.
     *
     * @return array ['keys' => list of JWK arrays with kty, kid, use, n, e]
     * @throws \UnexpectedValueException if a stored public key cannot be parsed
     */
    public function getJWKS() {

        $dao = $this->locator->getJWKSDAO();

        $keys = [];
        foreach ($dao->getRSAKeyPairs() as $keyPair) {

            $rsa = new RSA();
            if ($rsa->loadKey($keyPair->publicKey) === false) {
                throw new \UnexpectedValueException('Unable to load stored public key ' . $keyPair->keyId);
            }
            $rsa->setPublicKey();
            // phpseclib exposes modulus/exponent as standard base64 inside XML tags.
            $publicKeyXML = (string) $rsa->getPublicKey(RSA::PUBLIC_FORMAT_XML);

            $keys[] = [
                'kty' => "RSA",
                'kid' => $keyPair->keyId,
                'use' => "sig",
                'n' => self::toBase64Url($this->getTagContent($publicKeyXML, "Modulus")),
                'e' => self::toBase64Url($this->getTagContent($publicKeyXML, "Exponent")),
            ];
        }

        return [
            "keys" => $keys
        ];
    }

    /**
     * Returns the text content of the first <$tagname> element in the XML.
     *
     * @throws \UnexpectedValueException if the tag is not present
     */
    private function getTagContent(string $publicKeyXML, string $tagname): string {
        // Quote the tag name so it is matched literally inside the pattern.
        $quoted = preg_quote($tagname, '#');
        $pattern = "#<\s*?$quoted\b[^>]*>(.*?)</$quoted\b[^>]*>#s";
        $matches = [];
        if (preg_match($pattern, $publicKeyXML, $matches) !== 1) {
            throw new \UnexpectedValueException("Tag <$tagname> not found in public key XML");
        }
        return $matches[1];
    }

    /**
     * Converts standard base64 to unpadded base64url (RFC 4648 §5).
     */
    private static function toBase64Url(string $base64): string {
        return rtrim(strtr($base64, '+/', '-_'), '=');
    }

    /**
     * Converts (possibly unpadded) base64url back to padded standard base64.
     */
    private static function fromBase64Url(string $base64Url): string {
        $base64 = strtr($base64Url, '-_', '+/');
        $remainder = strlen($base64) % 4;
        if ($remainder !== 0) {
            $base64 .= str_repeat('=', 4 - $remainder);
        }
        return $base64;
    }

    /**
     * Refreshes the stored public keys from every configured JWKS URL, then
     * returns all public JWKs known to the DAO.
     *
     * @return array result of the DAO's getAllPublicJWK()
     */
    public function loadAllJWKS(): array {

        foreach ($this->locator->config->jwksUrls as $url) {
            $this->loadJWKS($url);
        }

        $dao = $this->locator->getJWKSDAO();
        return $dao->getAllPublicJWK();
    }

    /**
     * Downloads the JWKS at $url and upserts every usable RSA key through the
     * DAO. Failures are logged and never thrown (best-effort refresh).
     *
     * @param string $url remote JWKS endpoint
     */
    private function loadJWKS($url) {

        $conn = curl_init($url);
        if ($conn === false) {
            error_log('Error while retrieving JWKS from ' . $url);
            return;
        }

        try {
            curl_setopt($conn, CURLOPT_FOLLOWLOCATION, 1);
            curl_setopt($conn, CURLOPT_RETURNTRANSFER, true);

            $result = curl_exec($conn);
            $httpCode = curl_getinfo($conn, CURLINFO_HTTP_CODE);
        } finally {
            // Always release the handle, even if curl throws.
            curl_close($conn);
        }

        if ($httpCode !== 200 || !is_string($result)) {
            error_log('Error while retrieving JWKS from ' . $url);
            return;
        }

        $jwks = json_decode($result, true);
        if (!is_array($jwks) || !isset($jwks['keys']) || !is_array($jwks['keys'])) {
            error_log('Invalid JWKS document retrieved from ' . $url);
            return;
        }

        $dao = $this->locator->getJWKSDAO();

        foreach ($jwks['keys'] as $key) {
            if (!$this->isSupportedRsaJwk($key)) {
                // Non-RSA keys (e.g. EC) or malformed entries cannot be loaded here.
                error_log('Skipping unsupported or malformed JWK from ' . $url);
                continue;
            }
            $key['url'] = $url;
            try {
                $jwk = $this->getPublicJWK($key);
            } catch (\UnexpectedValueException $ex) {
                error_log($ex->getMessage() . ' (JWKS ' . $url . ')');
                continue;
            }
            $dao->updatePublicJWK($jwk);
        }
    }

    /**
     * Tells whether a decoded JWK entry is an RSA key whose "n" and "e"
     * members are present and contain only base64/base64url characters.
     *
     * The character check matters because the values are later embedded in
     * an XML document handed to phpseclib; remote data must not inject markup.
     */
    private function isSupportedRsaJwk($key): bool {
        if (!is_array($key)) {
            return false;
        }
        if (isset($key['kty']) && $key['kty'] !== 'RSA') {
            return false;
        }
        foreach (['n', 'e'] as $member) {
            if (!isset($key[$member]) || !is_string($key[$member])
                    || preg_match('#^[A-Za-z0-9+/_-]+={0,2}$#', $key[$member]) !== 1) {
                return false;
            }
        }
        return true;
    }

    /**
     * Converts a remote RSA JWK (already validated by isSupportedRsaJwk and
     * augmented with its source 'url') into a PublicJWK.
     *
     * @throws \UnexpectedValueException if phpseclib cannot load the key
     */
    private function getPublicJWK(array $data): PublicJWK {

        // JWK uses base64url; phpseclib's XML format expects standard base64.
        $n = self::fromBase64Url($data['n']);
        $e = self::fromBase64Url($data['e']);

        $key = "<RSAKeyPair>"
                . "<Modulus>" . $n . "</Modulus>"
                . "<Exponent>" . $e . "</Exponent>"
                . "</RSAKeyPair>";

        $rsa = new RSA();
        if ($rsa->loadKey($key, RSA::PUBLIC_FORMAT_XML) === false) {
            throw new \UnexpectedValueException('Unable to load RSA public key with kid '
                    . ($data['kid'] ?? '(none)'));
        }

        $jwk = new PublicJWK();
        $jwk->kid = $data['kid'];
        $jwk->key = $rsa;
        $jwk->url = $data['url'];
        $jwk->updateTime = time();

        return $jwk;
    }

}
