abouttreesummaryrefslogcommitdiff
path: root/ext/olm/src
diff options
context:
space:
mode:
authorpatrick-scho2023-11-13 19:58:33 +0100
committerpatrick-scho2023-11-13 19:58:33 +0100
commitda776f86b42946715c27edd64f7558b9d5080df1 (patch)
treec7e340bb253bd38f73368baeec7f12e914a39955 /ext/olm/src
parent21c6e8484b0bd05c27e5a91f2884d431926adc61 (diff)
downloadmatrix_esp_thesis-da776f86b42946715c27edd64f7558b9d5080df1.tar.gz
matrix_esp_thesis-da776f86b42946715c27edd64f7558b9d5080df1.zip
add dependencies to repo
Diffstat (limited to 'ext/olm/src')
-rw-r--r--ext/olm/src/account.cpp580
-rw-r--r--ext/olm/src/base64.cpp187
-rw-r--r--ext/olm/src/cipher.cpp152
-rw-r--r--ext/olm/src/crypto.cpp299
-rw-r--r--ext/olm/src/ed25519.c22
-rw-r--r--ext/olm/src/error.c46
-rw-r--r--ext/olm/src/inbound_group_session.c540
-rw-r--r--ext/olm/src/megolm.c154
-rw-r--r--ext/olm/src/memory.cpp45
-rw-r--r--ext/olm/src/message.cpp406
-rw-r--r--ext/olm/src/olm.cpp846
-rw-r--r--ext/olm/src/outbound_group_session.c390
-rw-r--r--ext/olm/src/pickle.cpp274
-rw-r--r--ext/olm/src/pickle_encoding.c92
-rw-r--r--ext/olm/src/pk.cpp542
-rw-r--r--ext/olm/src/ratchet.cpp625
-rw-r--r--ext/olm/src/sas.c229
-rw-r--r--ext/olm/src/session.cpp531
-rw-r--r--ext/olm/src/utility.cpp57
19 files changed, 6017 insertions, 0 deletions
diff --git a/ext/olm/src/account.cpp b/ext/olm/src/account.cpp
new file mode 100644
index 0000000..41b7188
--- /dev/null
+++ b/ext/olm/src/account.cpp
@@ -0,0 +1,580 @@
+/* Copyright 2015, 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "olm/account.hh"
+#include "olm/base64.hh"
+#include "olm/pickle.h"
+#include "olm/pickle.hh"
+#include "olm/memory.hh"
+
+olm::Account::Account(
+) : num_fallback_keys(0),
+ next_one_time_key_id(0),
+ last_error(OlmErrorCode::OLM_SUCCESS) {
+}
+
+
+olm::OneTimeKey const * olm::Account::lookup_key(
+ _olm_curve25519_public_key const & public_key
+) {
+ for (olm::OneTimeKey const & key : one_time_keys) {
+ if (olm::array_equal(key.key.public_key.public_key, public_key.public_key)) {
+ return &key;
+ }
+ }
+ if (num_fallback_keys >= 1
+ && olm::array_equal(
+ current_fallback_key.key.public_key.public_key, public_key.public_key
+ )
+ ) {
+ return &current_fallback_key;
+ }
+ if (num_fallback_keys >= 2
+ && olm::array_equal(
+ prev_fallback_key.key.public_key.public_key, public_key.public_key
+ )
+ ) {
+ return &prev_fallback_key;
+ }
+ return 0;
+}
+
+std::size_t olm::Account::remove_key(
+ _olm_curve25519_public_key const & public_key
+) {
+ OneTimeKey * i;
+ for (i = one_time_keys.begin(); i != one_time_keys.end(); ++i) {
+ if (olm::array_equal(i->key.public_key.public_key, public_key.public_key)) {
+ std::uint32_t id = i->id;
+ one_time_keys.erase(i);
+ return id;
+ }
+ }
+ // check if the key is a fallback key, to avoid returning an error, but
+ // don't actually remove it
+ if (num_fallback_keys >= 1
+ && olm::array_equal(
+ current_fallback_key.key.public_key.public_key, public_key.public_key
+ )
+ ) {
+ return current_fallback_key.id;
+ }
+ if (num_fallback_keys >= 2
+ && olm::array_equal(
+ prev_fallback_key.key.public_key.public_key, public_key.public_key
+ )
+ ) {
+ return prev_fallback_key.id;
+ }
+ return std::size_t(-1);
+}
+
+std::size_t olm::Account::new_account_random_length() const {
+ return ED25519_RANDOM_LENGTH + CURVE25519_RANDOM_LENGTH;
+}
+
+std::size_t olm::Account::new_account(
+ uint8_t const * random, std::size_t random_length
+) {
+ if (random_length < new_account_random_length()) {
+ last_error = OlmErrorCode::OLM_NOT_ENOUGH_RANDOM;
+ return std::size_t(-1);
+ }
+
+ _olm_crypto_ed25519_generate_key(random, &identity_keys.ed25519_key);
+ random += ED25519_RANDOM_LENGTH;
+ _olm_crypto_curve25519_generate_key(random, &identity_keys.curve25519_key);
+
+ return 0;
+}
+
+namespace {
+
+uint8_t KEY_JSON_ED25519[] = "\"ed25519\":";
+uint8_t KEY_JSON_CURVE25519[] = "\"curve25519\":";
+
+template<typename T>
+static std::uint8_t * write_string(
+ std::uint8_t * pos,
+ T const & value
+) {
+ std::memcpy(pos, value, sizeof(T) - 1);
+ return pos + (sizeof(T) - 1);
+}
+
+}
+
+
+std::size_t olm::Account::get_identity_json_length() const {
+ std::size_t length = 0;
+ length += 1; /* { */
+ length += sizeof(KEY_JSON_CURVE25519) - 1;
+ length += 1; /* " */
+ length += olm::encode_base64_length(
+ sizeof(identity_keys.curve25519_key.public_key)
+ );
+ length += 2; /* ", */
+ length += sizeof(KEY_JSON_ED25519) - 1;
+ length += 1; /* " */
+ length += olm::encode_base64_length(
+ sizeof(identity_keys.ed25519_key.public_key)
+ );
+ length += 2; /* "} */
+ return length;
+}
+
+
+std::size_t olm::Account::get_identity_json(
+ std::uint8_t * identity_json, std::size_t identity_json_length
+) {
+ std::uint8_t * pos = identity_json;
+ size_t expected_length = get_identity_json_length();
+
+ if (identity_json_length < expected_length) {
+ last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+
+ *(pos++) = '{';
+ pos = write_string(pos, KEY_JSON_CURVE25519);
+ *(pos++) = '\"';
+ pos = olm::encode_base64(
+ identity_keys.curve25519_key.public_key.public_key,
+ sizeof(identity_keys.curve25519_key.public_key.public_key),
+ pos
+ );
+ *(pos++) = '\"'; *(pos++) = ',';
+ pos = write_string(pos, KEY_JSON_ED25519);
+ *(pos++) = '\"';
+ pos = olm::encode_base64(
+ identity_keys.ed25519_key.public_key.public_key,
+ sizeof(identity_keys.ed25519_key.public_key.public_key),
+ pos
+ );
+ *(pos++) = '\"'; *(pos++) = '}';
+ return pos - identity_json;
+}
+
+
+std::size_t olm::Account::signature_length(
+) const {
+ return ED25519_SIGNATURE_LENGTH;
+}
+
+
+std::size_t olm::Account::sign(
+ std::uint8_t const * message, std::size_t message_length,
+ std::uint8_t * signature, std::size_t signature_length
+) {
+ if (signature_length < this->signature_length()) {
+ last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+ _olm_crypto_ed25519_sign(
+ &identity_keys.ed25519_key, message, message_length, signature
+ );
+ return this->signature_length();
+}
+
+
+std::size_t olm::Account::get_one_time_keys_json_length(
+) const {
+ std::size_t length = 0;
+ bool is_empty = true;
+ for (auto const & key : one_time_keys) {
+ if (key.published) {
+ continue;
+ }
+ is_empty = false;
+ length += 2; /* {" */
+ length += olm::encode_base64_length(_olm_pickle_uint32_length(key.id));
+ length += 3; /* ":" */
+ length += olm::encode_base64_length(sizeof(key.key.public_key));
+ length += 1; /* " */
+ }
+ if (is_empty) {
+ length += 1; /* { */
+ }
+ length += 3; /* }{} */
+ length += sizeof(KEY_JSON_CURVE25519) - 1;
+ return length;
+}
+
+
+std::size_t olm::Account::get_one_time_keys_json(
+ std::uint8_t * one_time_json, std::size_t one_time_json_length
+) {
+ std::uint8_t * pos = one_time_json;
+ if (one_time_json_length < get_one_time_keys_json_length()) {
+ last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+ *(pos++) = '{';
+ pos = write_string(pos, KEY_JSON_CURVE25519);
+ std::uint8_t sep = '{';
+ for (auto const & key : one_time_keys) {
+ if (key.published) {
+ continue;
+ }
+ *(pos++) = sep;
+ *(pos++) = '\"';
+ std::uint8_t key_id[_olm_pickle_uint32_length(key.id)];
+ _olm_pickle_uint32(key_id, key.id);
+ pos = olm::encode_base64(key_id, sizeof(key_id), pos);
+ *(pos++) = '\"'; *(pos++) = ':'; *(pos++) = '\"';
+ pos = olm::encode_base64(
+ key.key.public_key.public_key, sizeof(key.key.public_key.public_key), pos
+ );
+ *(pos++) = '\"';
+ sep = ',';
+ }
+ if (sep != ',') {
+ /* The list was empty */
+ *(pos++) = sep;
+ }
+ *(pos++) = '}';
+ *(pos++) = '}';
+ return pos - one_time_json;
+}
+
+
+std::size_t olm::Account::mark_keys_as_published(
+) {
+ std::size_t count = 0;
+ for (auto & key : one_time_keys) {
+ if (!key.published) {
+ key.published = true;
+ count++;
+ }
+ }
+ current_fallback_key.published = true;
+ return count;
+}
+
+
+std::size_t olm::Account::max_number_of_one_time_keys(
+) const {
+ return olm::MAX_ONE_TIME_KEYS;
+}
+
+std::size_t olm::Account::generate_one_time_keys_random_length(
+ std::size_t number_of_keys
+) const {
+ return CURVE25519_RANDOM_LENGTH * number_of_keys;
+}
+
+std::size_t olm::Account::generate_one_time_keys(
+ std::size_t number_of_keys,
+ std::uint8_t const * random, std::size_t random_length
+) {
+ if (random_length < generate_one_time_keys_random_length(number_of_keys)) {
+ last_error = OlmErrorCode::OLM_NOT_ENOUGH_RANDOM;
+ return std::size_t(-1);
+ }
+ for (unsigned i = 0; i < number_of_keys; ++i) {
+ OneTimeKey & key = *one_time_keys.insert(one_time_keys.begin());
+ key.id = ++next_one_time_key_id;
+ key.published = false;
+ _olm_crypto_curve25519_generate_key(random, &key.key);
+ random += CURVE25519_RANDOM_LENGTH;
+ }
+ return number_of_keys;
+}
+
+std::size_t olm::Account::generate_fallback_key_random_length() const {
+ return CURVE25519_RANDOM_LENGTH;
+}
+
+std::size_t olm::Account::generate_fallback_key(
+ std::uint8_t const * random, std::size_t random_length
+) {
+ if (random_length < generate_fallback_key_random_length()) {
+ last_error = OlmErrorCode::OLM_NOT_ENOUGH_RANDOM;
+ return std::size_t(-1);
+ }
+ if (num_fallback_keys < 2) {
+ num_fallback_keys++;
+ }
+ prev_fallback_key = current_fallback_key;
+ current_fallback_key.id = ++next_one_time_key_id;
+ current_fallback_key.published = false;
+ _olm_crypto_curve25519_generate_key(random, &current_fallback_key.key);
+ return 1;
+}
+
+
+std::size_t olm::Account::get_fallback_key_json_length(
+) const {
+ std::size_t length = 4 + sizeof(KEY_JSON_CURVE25519) - 1; /* {"curve25519":{}} */
+ if (num_fallback_keys >= 1) {
+ const OneTimeKey & key = current_fallback_key;
+ length += 1; /* " */
+ length += olm::encode_base64_length(_olm_pickle_uint32_length(key.id));
+ length += 3; /* ":" */
+ length += olm::encode_base64_length(sizeof(key.key.public_key));
+ length += 1; /* " */
+ }
+ return length;
+}
+
+std::size_t olm::Account::get_fallback_key_json(
+ std::uint8_t * fallback_json, std::size_t fallback_json_length
+) {
+ std::uint8_t * pos = fallback_json;
+ if (fallback_json_length < get_fallback_key_json_length()) {
+ last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+ *(pos++) = '{';
+ pos = write_string(pos, KEY_JSON_CURVE25519);
+ *(pos++) = '{';
+ OneTimeKey & key = current_fallback_key;
+ if (num_fallback_keys >= 1) {
+ *(pos++) = '\"';
+ std::uint8_t key_id[_olm_pickle_uint32_length(key.id)];
+ _olm_pickle_uint32(key_id, key.id);
+ pos = olm::encode_base64(key_id, sizeof(key_id), pos);
+ *(pos++) = '\"'; *(pos++) = ':'; *(pos++) = '\"';
+ pos = olm::encode_base64(
+ key.key.public_key.public_key, sizeof(key.key.public_key.public_key), pos
+ );
+ *(pos++) = '\"';
+ }
+ *(pos++) = '}';
+ *(pos++) = '}';
+ return pos - fallback_json;
+}
+
+std::size_t olm::Account::get_unpublished_fallback_key_json_length(
+) const {
+ std::size_t length = 4 + sizeof(KEY_JSON_CURVE25519) - 1; /* {"curve25519":{}} */
+ const OneTimeKey & key = current_fallback_key;
+ if (num_fallback_keys >= 1 && !key.published) {
+ length += 1; /* " */
+ length += olm::encode_base64_length(_olm_pickle_uint32_length(key.id));
+ length += 3; /* ":" */
+ length += olm::encode_base64_length(sizeof(key.key.public_key));
+ length += 1; /* " */
+ }
+ return length;
+}
+
+std::size_t olm::Account::get_unpublished_fallback_key_json(
+ std::uint8_t * fallback_json, std::size_t fallback_json_length
+) {
+ std::uint8_t * pos = fallback_json;
+ if (fallback_json_length < get_unpublished_fallback_key_json_length()) {
+ last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+ *(pos++) = '{';
+ pos = write_string(pos, KEY_JSON_CURVE25519);
+ *(pos++) = '{';
+ OneTimeKey & key = current_fallback_key;
+ if (num_fallback_keys >= 1 && !key.published) {
+ *(pos++) = '\"';
+ std::uint8_t key_id[_olm_pickle_uint32_length(key.id)];
+ _olm_pickle_uint32(key_id, key.id);
+ pos = olm::encode_base64(key_id, sizeof(key_id), pos);
+ *(pos++) = '\"'; *(pos++) = ':'; *(pos++) = '\"';
+ pos = olm::encode_base64(
+ key.key.public_key.public_key, sizeof(key.key.public_key.public_key), pos
+ );
+ *(pos++) = '\"';
+ }
+ *(pos++) = '}';
+ *(pos++) = '}';
+ return pos - fallback_json;
+}
+
+void olm::Account::forget_old_fallback_key(
+) {
+ if (num_fallback_keys >= 2) {
+ num_fallback_keys = 1;
+ olm::unset(&prev_fallback_key, sizeof(prev_fallback_key));
+ }
+}
+
+namespace olm {
+
+static std::size_t pickle_length(
+ olm::IdentityKeys const & value
+) {
+ size_t length = 0;
+ length += _olm_pickle_ed25519_key_pair_length(&value.ed25519_key);
+ length += olm::pickle_length(value.curve25519_key);
+ return length;
+}
+
+
+static std::uint8_t * pickle(
+ std::uint8_t * pos,
+ olm::IdentityKeys const & value
+) {
+ pos = _olm_pickle_ed25519_key_pair(pos, &value.ed25519_key);
+ pos = olm::pickle(pos, value.curve25519_key);
+ return pos;
+}
+
+
+static std::uint8_t const * unpickle(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ olm::IdentityKeys & value
+) {
+ pos = _olm_unpickle_ed25519_key_pair(pos, end, &value.ed25519_key); UNPICKLE_OK(pos);
+ pos = olm::unpickle(pos, end, value.curve25519_key); UNPICKLE_OK(pos);
+ return pos;
+}
+
+
+static std::size_t pickle_length(
+ olm::OneTimeKey const & value
+) {
+ std::size_t length = 0;
+ length += olm::pickle_length(value.id);
+ length += olm::pickle_length(value.published);
+ length += olm::pickle_length(value.key);
+ return length;
+}
+
+
+static std::uint8_t * pickle(
+ std::uint8_t * pos,
+ olm::OneTimeKey const & value
+) {
+ pos = olm::pickle(pos, value.id);
+ pos = olm::pickle(pos, value.published);
+ pos = olm::pickle(pos, value.key);
+ return pos;
+}
+
+
+static std::uint8_t const * unpickle(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ olm::OneTimeKey & value
+) {
+ pos = olm::unpickle(pos, end, value.id); UNPICKLE_OK(pos);
+ pos = olm::unpickle(pos, end, value.published); UNPICKLE_OK(pos);
+ pos = olm::unpickle(pos, end, value.key); UNPICKLE_OK(pos);
+ return pos;
+}
+
+} // namespace olm
+
+namespace {
+// pickle version 1 used only 32 bytes for the ed25519 private key.
+// Any keys thus used should be considered compromised.
+// pickle version 2 does not have fallback keys.
+// pickle version 3 does not store whether the current fallback key is published.
+static const std::uint32_t ACCOUNT_PICKLE_VERSION = 4;
+}
+
+
+std::size_t olm::pickle_length(
+ olm::Account const & value
+) {
+ std::size_t length = 0;
+ length += olm::pickle_length(ACCOUNT_PICKLE_VERSION);
+ length += olm::pickle_length(value.identity_keys);
+ length += olm::pickle_length(value.one_time_keys);
+ length += olm::pickle_length(value.num_fallback_keys);
+ if (value.num_fallback_keys >= 1) {
+ length += olm::pickle_length(value.current_fallback_key);
+ if (value.num_fallback_keys >= 2) {
+ length += olm::pickle_length(value.prev_fallback_key);
+ }
+ }
+ length += olm::pickle_length(value.next_one_time_key_id);
+ return length;
+}
+
+
+std::uint8_t * olm::pickle(
+ std::uint8_t * pos,
+ olm::Account const & value
+) {
+ pos = olm::pickle(pos, ACCOUNT_PICKLE_VERSION);
+ pos = olm::pickle(pos, value.identity_keys);
+ pos = olm::pickle(pos, value.one_time_keys);
+ pos = olm::pickle(pos, value.num_fallback_keys);
+ if (value.num_fallback_keys >= 1) {
+ pos = olm::pickle(pos, value.current_fallback_key);
+ if (value.num_fallback_keys >= 2) {
+ pos = olm::pickle(pos, value.prev_fallback_key);
+ }
+ }
+ pos = olm::pickle(pos, value.next_one_time_key_id);
+ return pos;
+}
+
+
+std::uint8_t const * olm::unpickle(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ olm::Account & value
+) {
+ uint32_t pickle_version;
+
+ pos = olm::unpickle(pos, end, pickle_version); UNPICKLE_OK(pos);
+
+ switch (pickle_version) {
+ case ACCOUNT_PICKLE_VERSION:
+ case 3:
+ case 2:
+ break;
+ case 1:
+ value.last_error = OlmErrorCode::OLM_BAD_LEGACY_ACCOUNT_PICKLE;
+ return nullptr;
+ default:
+ value.last_error = OlmErrorCode::OLM_UNKNOWN_PICKLE_VERSION;
+ return nullptr;
+ }
+
+ pos = olm::unpickle(pos, end, value.identity_keys); UNPICKLE_OK(pos);
+ pos = olm::unpickle(pos, end, value.one_time_keys); UNPICKLE_OK(pos);
+
+ if (pickle_version <= 2) {
+ // version 2 did not have fallback keys
+ value.num_fallback_keys = 0;
+ } else if (pickle_version == 3) {
+ // version 3 used the published flag to indicate how many fallback keys
+ // were present (we'll have to assume that the keys were published)
+ pos = olm::unpickle(pos, end, value.current_fallback_key); UNPICKLE_OK(pos);
+ pos = olm::unpickle(pos, end, value.prev_fallback_key); UNPICKLE_OK(pos);
+ if (value.current_fallback_key.published) {
+ if (value.prev_fallback_key.published) {
+ value.num_fallback_keys = 2;
+ } else {
+ value.num_fallback_keys = 1;
+ }
+ } else {
+ value.num_fallback_keys = 0;
+ }
+ } else {
+ pos = olm::unpickle(pos, end, value.num_fallback_keys); UNPICKLE_OK(pos);
+ if (value.num_fallback_keys >= 1) {
+ pos = olm::unpickle(pos, end, value.current_fallback_key); UNPICKLE_OK(pos);
+ if (value.num_fallback_keys >= 2) {
+ pos = olm::unpickle(pos, end, value.prev_fallback_key); UNPICKLE_OK(pos);
+ if (value.num_fallback_keys >= 3) {
+ value.last_error = OlmErrorCode::OLM_CORRUPTED_PICKLE;
+ return nullptr;
+ }
+ }
+ }
+ }
+
+ pos = olm::unpickle(pos, end, value.next_one_time_key_id); UNPICKLE_OK(pos);
+
+ return pos;
+}
diff --git a/ext/olm/src/base64.cpp b/ext/olm/src/base64.cpp
new file mode 100644
index 0000000..0e195fb
--- /dev/null
+++ b/ext/olm/src/base64.cpp
@@ -0,0 +1,187 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <cassert>
+
+#include "olm/base64.h"
+#include "olm/base64.hh"
+
+namespace {
+
+static const std::uint8_t ENCODE_BASE64[64] = {
+ 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48,
+ 0x49, 0x4A, 0x4B, 0x4C, 0x4D, 0x4E, 0x4F, 0x50,
+ 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58,
+ 0x59, 0x5A, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66,
+ 0x67, 0x68, 0x69, 0x6A, 0x6B, 0x6C, 0x6D, 0x6E,
+ 0x6F, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76,
+ 0x77, 0x78, 0x79, 0x7A, 0x30, 0x31, 0x32, 0x33,
+ 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x2B, 0x2F,
+};
+
+static const std::uint8_t E = -1;
+
+static const std::uint8_t DECODE_BASE64[128] = {
+/* 0x0 0x1 0x2 0x3 0x4 0x5 0x6 0x7 0x8 0x9 0xA 0xB 0xC 0xD 0xE 0xF */
+ E, E, E, E, E, E, E, E, E, E, E, E, E, E, E, E,
+ E, E, E, E, E, E, E, E, E, E, E, E, E, E, E, E,
+ E, E, E, E, E, E, E, E, E, E, E, 62, E, E, E, 63,
+ 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, E, E, E, E, E, E,
+ E, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
+ 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, E, E, E, E, E,
+ E, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
+ 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, E, E, E, E, E,
+};
+
+} // namespace
+
+
+std::size_t olm::encode_base64_length(
+ std::size_t input_length
+) {
+ return 4 * ((input_length + 2) / 3) + (input_length + 2) % 3 - 2;
+}
+
+std::uint8_t * olm::encode_base64(
+ std::uint8_t const * input, std::size_t input_length,
+ std::uint8_t * output
+) {
+ std::uint8_t const * end = input + (input_length / 3) * 3;
+ std::uint8_t const * pos = input;
+ while (pos != end) {
+ unsigned value = pos[0];
+ value <<= 8; value |= pos[1];
+ value <<= 8; value |= pos[2];
+ pos += 3;
+ output[3] = ENCODE_BASE64[value & 0x3F];
+ value >>= 6; output[2] = ENCODE_BASE64[value & 0x3F];
+ value >>= 6; output[1] = ENCODE_BASE64[value & 0x3F];
+ value >>= 6; output[0] = ENCODE_BASE64[value];
+ output += 4;
+ }
+ unsigned remainder = input + input_length - pos;
+ std::uint8_t * result = output;
+ if (remainder) {
+ unsigned value = pos[0];
+ if (remainder == 2) {
+ value <<= 8; value |= pos[1];
+ value <<= 2;
+ output[2] = ENCODE_BASE64[value & 0x3F];
+ value >>= 6;
+ result += 3;
+ } else {
+ value <<= 4;
+ result += 2;
+ }
+ output[1] = ENCODE_BASE64[value & 0x3F];
+ value >>= 6;
+ output[0] = ENCODE_BASE64[value];
+ }
+ return result;
+}
+
+
+std::size_t olm::decode_base64_length(
+ std::size_t input_length
+) {
+ if (input_length % 4 == 1) {
+ return std::size_t(-1);
+ } else {
+ return 3 * ((input_length + 2) / 4) + (input_length + 2) % 4 - 2;
+ }
+}
+
+
+std::size_t olm::decode_base64(
+ std::uint8_t const * input, std::size_t input_length,
+ std::uint8_t * output
+) {
+ size_t raw_length = olm::decode_base64_length(input_length);
+
+ if (raw_length == std::size_t(-1)) {
+ return std::size_t(-1);
+ }
+
+ std::uint8_t const * end = input + (input_length / 4) * 4;
+ std::uint8_t const * pos = input;
+
+ while (pos != end) {
+ unsigned value = DECODE_BASE64[pos[0] & 0x7F];
+ value <<= 6; value |= DECODE_BASE64[pos[1] & 0x7F];
+ value <<= 6; value |= DECODE_BASE64[pos[2] & 0x7F];
+ value <<= 6; value |= DECODE_BASE64[pos[3] & 0x7F];
+ pos += 4;
+ output[2] = value;
+ value >>= 8; output[1] = value;
+ value >>= 8; output[0] = value;
+ output += 3;
+ }
+
+ unsigned remainder = input + input_length - pos;
+ if (remainder) {
+ /* A base64 payload with a single byte remainder cannot occur because
+ * a single base64 character only encodes 6 bits, which is less than
+ * a full byte. Therefore, a minimum of two base64 characters are
+ * required to construct a single output byte and payloads with
+ * a remainder of 1 are illegal.
+ *
+ * Should never be the case due to length check above.
+ */
+ assert(remainder != 1);
+
+ unsigned value = DECODE_BASE64[pos[0] & 0x7F];
+ value <<= 6; value |= DECODE_BASE64[pos[1] & 0x7F];
+ if (remainder == 3) {
+ value <<= 6; value |= DECODE_BASE64[pos[2] & 0x7F];
+ value >>= 2;
+ output[1] = value;
+ value >>= 8;
+ } else {
+ value >>= 4;
+ }
+ output[0] = value;
+ }
+
+ return raw_length;
+}
+
+
+// implementations of base64.h
+
+size_t _olm_encode_base64_length(
+ size_t input_length
+) {
+ return olm::encode_base64_length(input_length);
+}
+
+size_t _olm_encode_base64(
+ uint8_t const * input, size_t input_length,
+ uint8_t * output
+) {
+ uint8_t * r = olm::encode_base64(input, input_length, output);
+ return r - output;
+}
+
+size_t _olm_decode_base64_length(
+ size_t input_length
+) {
+ return olm::decode_base64_length(input_length);
+}
+
+size_t _olm_decode_base64(
+ uint8_t const * input, size_t input_length,
+ uint8_t * output
+) {
+ return olm::decode_base64(input, input_length, output);
+}
diff --git a/ext/olm/src/cipher.cpp b/ext/olm/src/cipher.cpp
new file mode 100644
index 0000000..2312b84
--- /dev/null
+++ b/ext/olm/src/cipher.cpp
@@ -0,0 +1,152 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "olm/cipher.h"
+#include "olm/crypto.h"
+#include "olm/memory.hh"
+#include <cstring>
+
+const std::size_t HMAC_KEY_LENGTH = 32;
+
+namespace {
+
+struct DerivedKeys {
+ _olm_aes256_key aes_key;
+ std::uint8_t mac_key[HMAC_KEY_LENGTH];
+ _olm_aes256_iv aes_iv;
+};
+
+
+static void derive_keys(
+ std::uint8_t const * kdf_info, std::size_t kdf_info_length,
+ std::uint8_t const * key, std::size_t key_length,
+ DerivedKeys & keys
+) {
+ std::uint8_t derived_secrets[
+ AES256_KEY_LENGTH + HMAC_KEY_LENGTH + AES256_IV_LENGTH
+ ];
+ _olm_crypto_hkdf_sha256(
+ key, key_length,
+ nullptr, 0,
+ kdf_info, kdf_info_length,
+ derived_secrets, sizeof(derived_secrets)
+ );
+ std::uint8_t const * pos = derived_secrets;
+ pos = olm::load_array(keys.aes_key.key, pos);
+ pos = olm::load_array(keys.mac_key, pos);
+ pos = olm::load_array(keys.aes_iv.iv, pos);
+ olm::unset(derived_secrets);
+}
+
+static const std::size_t MAC_LENGTH = 8;
+
+size_t aes_sha_256_cipher_mac_length(const struct _olm_cipher *cipher) {
+ return MAC_LENGTH;
+}
+
+size_t aes_sha_256_cipher_encrypt_ciphertext_length(
+ const struct _olm_cipher *cipher, size_t plaintext_length
+) {
+ return _olm_crypto_aes_encrypt_cbc_length(plaintext_length);
+}
+
+size_t aes_sha_256_cipher_encrypt(
+ const struct _olm_cipher *cipher,
+ uint8_t const * key, size_t key_length,
+ uint8_t const * plaintext, size_t plaintext_length,
+ uint8_t * ciphertext, size_t ciphertext_length,
+ uint8_t * output, size_t output_length
+) {
+ auto *c = reinterpret_cast<const _olm_cipher_aes_sha_256 *>(cipher);
+
+ if (ciphertext_length
+ < aes_sha_256_cipher_encrypt_ciphertext_length(cipher, plaintext_length)
+ || output_length < MAC_LENGTH) {
+ return std::size_t(-1);
+ }
+
+ struct DerivedKeys keys;
+ std::uint8_t mac[SHA256_OUTPUT_LENGTH];
+
+ derive_keys(c->kdf_info, c->kdf_info_length, key, key_length, keys);
+
+ _olm_crypto_aes_encrypt_cbc(
+ &keys.aes_key, &keys.aes_iv, plaintext, plaintext_length, ciphertext
+ );
+
+ _olm_crypto_hmac_sha256(
+ keys.mac_key, HMAC_KEY_LENGTH, output, output_length - MAC_LENGTH, mac
+ );
+
+ std::memcpy(output + output_length - MAC_LENGTH, mac, MAC_LENGTH);
+
+ olm::unset(keys);
+ return output_length;
+}
+
+
+size_t aes_sha_256_cipher_decrypt_max_plaintext_length(
+ const struct _olm_cipher *cipher,
+ size_t ciphertext_length
+) {
+ return ciphertext_length;
+}
+
+size_t aes_sha_256_cipher_decrypt(
+ const struct _olm_cipher *cipher,
+ uint8_t const * key, size_t key_length,
+ uint8_t const * input, size_t input_length,
+ uint8_t const * ciphertext, size_t ciphertext_length,
+ uint8_t * plaintext, size_t max_plaintext_length
+) {
+ if (max_plaintext_length
+ < aes_sha_256_cipher_decrypt_max_plaintext_length(cipher, ciphertext_length)
+ || input_length < MAC_LENGTH) {
+ return std::size_t(-1);
+ }
+
+ auto *c = reinterpret_cast<const _olm_cipher_aes_sha_256 *>(cipher);
+
+ DerivedKeys keys;
+ std::uint8_t mac[SHA256_OUTPUT_LENGTH];
+
+ derive_keys(c->kdf_info, c->kdf_info_length, key, key_length, keys);
+
+ _olm_crypto_hmac_sha256(
+ keys.mac_key, HMAC_KEY_LENGTH, input, input_length - MAC_LENGTH, mac
+ );
+
+ std::uint8_t const * input_mac = input + input_length - MAC_LENGTH;
+ if (!olm::is_equal(input_mac, mac, MAC_LENGTH)) {
+ olm::unset(keys);
+ return std::size_t(-1);
+ }
+
+ std::size_t plaintext_length = _olm_crypto_aes_decrypt_cbc(
+ &keys.aes_key, &keys.aes_iv, ciphertext, ciphertext_length, plaintext
+ );
+
+ olm::unset(keys);
+ return plaintext_length;
+}
+
+} // namespace
+
+const struct _olm_cipher_ops _olm_cipher_aes_sha_256_ops = {
+ aes_sha_256_cipher_mac_length,
+ aes_sha_256_cipher_encrypt_ciphertext_length,
+ aes_sha_256_cipher_encrypt,
+ aes_sha_256_cipher_decrypt_max_plaintext_length,
+ aes_sha_256_cipher_decrypt,
+};
diff --git a/ext/olm/src/crypto.cpp b/ext/olm/src/crypto.cpp
new file mode 100644
index 0000000..e297513
--- /dev/null
+++ b/ext/olm/src/crypto.cpp
@@ -0,0 +1,299 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "olm/crypto.h"
+#include "olm/memory.hh"
+
+#include <cstring>
+
+extern "C" {
+
+#include "crypto-algorithms/aes.h"
+#include "crypto-algorithms/sha256.h"
+
+}
+
+#include "ed25519/src/ed25519.h"
+#include "curve25519-donna.h"
+
+namespace {
+
+static const std::uint8_t CURVE25519_BASEPOINT[32] = {9};
+static const std::size_t AES_KEY_SCHEDULE_LENGTH = 60;
+static const std::size_t AES_KEY_BITS = 8 * AES256_KEY_LENGTH;
+static const std::size_t AES_BLOCK_LENGTH = 16;
+static const std::size_t SHA256_BLOCK_LENGTH = 64;
+static const std::uint8_t HKDF_DEFAULT_SALT[32] = {};
+
+
+template<std::size_t block_size>
+inline static void xor_block(
+ std::uint8_t * block,
+ std::uint8_t const * input
+) {
+ for (std::size_t i = 0; i < block_size; ++i) {
+ block[i] ^= input[i];
+ }
+}
+
+
+inline static void hmac_sha256_key(
+ std::uint8_t const * input_key, std::size_t input_key_length,
+ std::uint8_t * hmac_key
+) {
+ std::memset(hmac_key, 0, SHA256_BLOCK_LENGTH);
+ if (input_key_length > SHA256_BLOCK_LENGTH) {
+ ::SHA256_CTX context;
+ ::sha256_init(&context);
+ ::sha256_update(&context, input_key, input_key_length);
+ ::sha256_final(&context, hmac_key);
+ } else {
+ std::memcpy(hmac_key, input_key, input_key_length);
+ }
+}
+
+
+inline static void hmac_sha256_init(
+ ::SHA256_CTX * context,
+ std::uint8_t const * hmac_key
+) {
+ std::uint8_t i_pad[SHA256_BLOCK_LENGTH];
+ std::memcpy(i_pad, hmac_key, SHA256_BLOCK_LENGTH);
+ for (std::size_t i = 0; i < SHA256_BLOCK_LENGTH; ++i) {
+ i_pad[i] ^= 0x36;
+ }
+ ::sha256_init(context);
+ ::sha256_update(context, i_pad, SHA256_BLOCK_LENGTH);
+ olm::unset(i_pad);
+}
+
+
+inline static void hmac_sha256_final(
+ ::SHA256_CTX * context,
+ std::uint8_t const * hmac_key,
+ std::uint8_t * output
+) {
+ std::uint8_t o_pad[SHA256_BLOCK_LENGTH + SHA256_OUTPUT_LENGTH];
+ std::memcpy(o_pad, hmac_key, SHA256_BLOCK_LENGTH);
+ for (std::size_t i = 0; i < SHA256_BLOCK_LENGTH; ++i) {
+ o_pad[i] ^= 0x5C;
+ }
+ ::sha256_final(context, o_pad + SHA256_BLOCK_LENGTH);
+ ::SHA256_CTX final_context;
+ ::sha256_init(&final_context);
+ ::sha256_update(&final_context, o_pad, sizeof(o_pad));
+ ::sha256_final(&final_context, output);
+ olm::unset(final_context);
+ olm::unset(o_pad);
+}
+
+} // namespace
+
+void _olm_crypto_curve25519_generate_key(
+ uint8_t const * random_32_bytes,
+ struct _olm_curve25519_key_pair *key_pair
+) {
+ std::memcpy(
+ key_pair->private_key.private_key, random_32_bytes,
+ CURVE25519_KEY_LENGTH
+ );
+ ::curve25519_donna(
+ key_pair->public_key.public_key,
+ key_pair->private_key.private_key,
+ CURVE25519_BASEPOINT
+ );
+}
+
+
+void _olm_crypto_curve25519_shared_secret(
+ const struct _olm_curve25519_key_pair *our_key,
+ const struct _olm_curve25519_public_key * their_key,
+ std::uint8_t * output
+) {
+ ::curve25519_donna(output, our_key->private_key.private_key, their_key->public_key);
+}
+
+
+void _olm_crypto_ed25519_generate_key(
+ std::uint8_t const * random_32_bytes,
+ struct _olm_ed25519_key_pair *key_pair
+) {
+ ::ed25519_create_keypair(
+ key_pair->public_key.public_key, key_pair->private_key.private_key,
+ random_32_bytes
+ );
+}
+
+
+void _olm_crypto_ed25519_sign(
+ const struct _olm_ed25519_key_pair *our_key,
+ std::uint8_t const * message, std::size_t message_length,
+ std::uint8_t * output
+) {
+ ::ed25519_sign(
+ output,
+ message, message_length,
+ our_key->public_key.public_key,
+ our_key->private_key.private_key
+ );
+}
+
+
+int _olm_crypto_ed25519_verify(
+ const struct _olm_ed25519_public_key *their_key,
+ std::uint8_t const * message, std::size_t message_length,
+ std::uint8_t const * signature
+) {
+ return 0 != ::ed25519_verify(
+ signature,
+ message, message_length,
+ their_key->public_key
+ );
+}
+
+
+std::size_t _olm_crypto_aes_encrypt_cbc_length(
+ std::size_t input_length
+) {
+ return input_length + AES_BLOCK_LENGTH - input_length % AES_BLOCK_LENGTH;
+}
+
+
+void _olm_crypto_aes_encrypt_cbc(
+ _olm_aes256_key const *key,
+ _olm_aes256_iv const *iv,
+ std::uint8_t const * input, std::size_t input_length,
+ std::uint8_t * output
+) {
+ std::uint32_t key_schedule[AES_KEY_SCHEDULE_LENGTH];
+ ::_olm_aes_key_setup(key->key, key_schedule, AES_KEY_BITS);
+ std::uint8_t input_block[AES_BLOCK_LENGTH];
+ std::memcpy(input_block, iv->iv, AES_BLOCK_LENGTH);
+ while (input_length >= AES_BLOCK_LENGTH) {
+ xor_block<AES_BLOCK_LENGTH>(input_block, input);
+ ::_olm_aes_encrypt(input_block, output, key_schedule, AES_KEY_BITS);
+ std::memcpy(input_block, output, AES_BLOCK_LENGTH);
+ input += AES_BLOCK_LENGTH;
+ output += AES_BLOCK_LENGTH;
+ input_length -= AES_BLOCK_LENGTH;
+ }
+ std::size_t i = 0;
+ for (; i < input_length; ++i) {
+ input_block[i] ^= input[i];
+ }
+ for (; i < AES_BLOCK_LENGTH; ++i) {
+ input_block[i] ^= AES_BLOCK_LENGTH - input_length;
+ }
+ ::_olm_aes_encrypt(input_block, output, key_schedule, AES_KEY_BITS);
+ olm::unset(key_schedule);
+ olm::unset(input_block);
+}
+
+
+std::size_t _olm_crypto_aes_decrypt_cbc(
+ _olm_aes256_key const *key,
+ _olm_aes256_iv const *iv,
+ std::uint8_t const * input, std::size_t input_length,
+ std::uint8_t * output
+) {
+ std::uint32_t key_schedule[AES_KEY_SCHEDULE_LENGTH];
+ ::_olm_aes_key_setup(key->key, key_schedule, AES_KEY_BITS);
+ std::uint8_t block1[AES_BLOCK_LENGTH];
+ std::uint8_t block2[AES_BLOCK_LENGTH];
+ std::memcpy(block1, iv->iv, AES_BLOCK_LENGTH);
+ for (std::size_t i = 0; i < input_length; i += AES_BLOCK_LENGTH) {
+ std::memcpy(block2, &input[i], AES_BLOCK_LENGTH);
+ ::_olm_aes_decrypt(&input[i], &output[i], key_schedule, AES_KEY_BITS);
+ xor_block<AES_BLOCK_LENGTH>(&output[i], block1);
+ std::memcpy(block1, block2, AES_BLOCK_LENGTH);
+ }
+ olm::unset(key_schedule);
+ olm::unset(block1);
+ olm::unset(block2);
+ std::size_t padding = output[input_length - 1];
+ return (padding > input_length) ? std::size_t(-1) : (input_length - padding);
+}
+
+
+void _olm_crypto_sha256(
+ std::uint8_t const * input, std::size_t input_length,
+ std::uint8_t * output
+) {
+ ::SHA256_CTX context;
+ ::sha256_init(&context);
+ ::sha256_update(&context, input, input_length);
+ ::sha256_final(&context, output);
+ olm::unset(context);
+}
+
+
+void _olm_crypto_hmac_sha256(
+ std::uint8_t const * key, std::size_t key_length,
+ std::uint8_t const * input, std::size_t input_length,
+ std::uint8_t * output
+) {
+ std::uint8_t hmac_key[SHA256_BLOCK_LENGTH];
+ ::SHA256_CTX context;
+ hmac_sha256_key(key, key_length, hmac_key);
+ hmac_sha256_init(&context, hmac_key);
+ ::sha256_update(&context, input, input_length);
+ hmac_sha256_final(&context, hmac_key, output);
+ olm::unset(hmac_key);
+ olm::unset(context);
+}
+
+
+void _olm_crypto_hkdf_sha256(
+ std::uint8_t const * input, std::size_t input_length,
+ std::uint8_t const * salt, std::size_t salt_length,
+ std::uint8_t const * info, std::size_t info_length,
+ std::uint8_t * output, std::size_t output_length
+) {
+ ::SHA256_CTX context;
+ std::uint8_t hmac_key[SHA256_BLOCK_LENGTH];
+ std::uint8_t step_result[SHA256_OUTPUT_LENGTH];
+ std::size_t bytes_remaining = output_length;
+ std::uint8_t iteration = 1;
+ if (!salt) {
+ salt = HKDF_DEFAULT_SALT;
+ salt_length = sizeof(HKDF_DEFAULT_SALT);
+ }
+ /* Extract */
+ hmac_sha256_key(salt, salt_length, hmac_key);
+ hmac_sha256_init(&context, hmac_key);
+ ::sha256_update(&context, input, input_length);
+ hmac_sha256_final(&context, hmac_key, step_result);
+ hmac_sha256_key(step_result, SHA256_OUTPUT_LENGTH, hmac_key);
+
+ /* Expand */
+ hmac_sha256_init(&context, hmac_key);
+ ::sha256_update(&context, info, info_length);
+ ::sha256_update(&context, &iteration, 1);
+ hmac_sha256_final(&context, hmac_key, step_result);
+ while (bytes_remaining > SHA256_OUTPUT_LENGTH) {
+ std::memcpy(output, step_result, SHA256_OUTPUT_LENGTH);
+ output += SHA256_OUTPUT_LENGTH;
+ bytes_remaining -= SHA256_OUTPUT_LENGTH;
+ iteration ++;
+ hmac_sha256_init(&context, hmac_key);
+ ::sha256_update(&context, step_result, SHA256_OUTPUT_LENGTH);
+ ::sha256_update(&context, info, info_length);
+ ::sha256_update(&context, &iteration, 1);
+ hmac_sha256_final(&context, hmac_key, step_result);
+ }
+ std::memcpy(output, step_result, bytes_remaining);
+ olm::unset(context);
+ olm::unset(hmac_key);
+ olm::unset(step_result);
+}
diff --git a/ext/olm/src/ed25519.c b/ext/olm/src/ed25519.c
new file mode 100644
index 0000000..c7a1a8e
--- /dev/null
+++ b/ext/olm/src/ed25519.c
@@ -0,0 +1,22 @@
+/* Copyright 2015-6 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#define select ed25519_select
+#include "ed25519/src/fe.c"
+#include "ed25519/src/sc.c"
+#include "ed25519/src/ge.c"
+#include "ed25519/src/keypair.c"
+#include "ed25519/src/sha512.c"
+#include "ed25519/src/verify.c"
+#include "ed25519/src/sign.c"
diff --git a/ext/olm/src/error.c b/ext/olm/src/error.c
new file mode 100644
index 0000000..6775eee
--- /dev/null
+++ b/ext/olm/src/error.c
@@ -0,0 +1,46 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "olm/error.h"
+
+static const char * ERRORS[] = {
+ "SUCCESS",
+ "NOT_ENOUGH_RANDOM",
+ "OUTPUT_BUFFER_TOO_SMALL",
+ "BAD_MESSAGE_VERSION",
+ "BAD_MESSAGE_FORMAT",
+ "BAD_MESSAGE_MAC",
+ "BAD_MESSAGE_KEY_ID",
+ "INVALID_BASE64",
+ "BAD_ACCOUNT_KEY",
+ "UNKNOWN_PICKLE_VERSION",
+ "CORRUPTED_PICKLE",
+ "BAD_SESSION_KEY",
+ "UNKNOWN_MESSAGE_INDEX",
+ "BAD_LEGACY_ACCOUNT_PICKLE",
+ "BAD_SIGNATURE",
+ "OLM_INPUT_BUFFER_TOO_SMALL",
+ "OLM_SAS_THEIR_KEY_NOT_SET",
+ "OLM_PICKLE_EXTRA_DATA"
+};
+
+const char * _olm_error_to_string(enum OlmErrorCode error)
+{
+ if (error < (sizeof(ERRORS)/sizeof(ERRORS[0]))) {
+ return ERRORS[error];
+ } else {
+ return "UNKNOWN_ERROR";
+ }
+}
diff --git a/ext/olm/src/inbound_group_session.c b/ext/olm/src/inbound_group_session.c
new file mode 100644
index 0000000..d6f73b7
--- /dev/null
+++ b/ext/olm/src/inbound_group_session.c
@@ -0,0 +1,540 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "olm/inbound_group_session.h"
+
+#include <string.h>
+
+#include "olm/base64.h"
+#include "olm/cipher.h"
+#include "olm/crypto.h"
+#include "olm/error.h"
+#include "olm/megolm.h"
+#include "olm/memory.h"
+#include "olm/message.h"
+#include "olm/pickle.h"
+#include "olm/pickle_encoding.h"
+
+
+#define OLM_PROTOCOL_VERSION 3
+#define GROUP_SESSION_ID_LENGTH ED25519_PUBLIC_KEY_LENGTH
+#define PICKLE_VERSION 2
+#define SESSION_KEY_VERSION 2
+#define SESSION_EXPORT_VERSION 1
+
+struct OlmInboundGroupSession {
+ /** our earliest known ratchet value */
+ Megolm initial_ratchet;
+
+ /** The most recent ratchet value */
+ Megolm latest_ratchet;
+
+ /** The ed25519 signing key */
+ struct _olm_ed25519_public_key signing_key;
+
+ /**
+ * Have we ever seen any evidence that this is a valid session?
+ * (either because the original session share was signed, or because we
+ * have subsequently successfully decrypted a message)
+ *
+ * (We don't do anything with this currently, but we may want to bear it in
+ * mind when we consider handling key-shares for sessions we already know
+ * about.)
+ */
+ int signing_key_verified;
+
+ enum OlmErrorCode last_error;
+};
+
+size_t olm_inbound_group_session_size(void) {
+ return sizeof(OlmInboundGroupSession);
+}
+
+OlmInboundGroupSession * olm_inbound_group_session(
+ void *memory
+) {
+ OlmInboundGroupSession *session = memory;
+ olm_clear_inbound_group_session(session);
+ return session;
+}
+
+const char *olm_inbound_group_session_last_error(
+ const OlmInboundGroupSession *session
+) {
+ return _olm_error_to_string(session->last_error);
+}
+
+enum OlmErrorCode olm_inbound_group_session_last_error_code(
+ const OlmInboundGroupSession *session
+) {
+ return session->last_error;
+}
+
+size_t olm_clear_inbound_group_session(
+ OlmInboundGroupSession *session
+) {
+ _olm_unset(session, sizeof(OlmInboundGroupSession));
+ return sizeof(OlmInboundGroupSession);
+}
+
+#define SESSION_EXPORT_RAW_LENGTH \
+ (1 + 4 + MEGOLM_RATCHET_LENGTH + ED25519_PUBLIC_KEY_LENGTH)
+
+#define SESSION_KEY_RAW_LENGTH \
+ (1 + 4 + MEGOLM_RATCHET_LENGTH + ED25519_PUBLIC_KEY_LENGTH\
+ + ED25519_SIGNATURE_LENGTH)
+
+static size_t _init_group_session_keys(
+ OlmInboundGroupSession *session,
+ const uint8_t *key_buf,
+ int export_format
+) {
+ const uint8_t expected_version =
+ (export_format ? SESSION_EXPORT_VERSION : SESSION_KEY_VERSION);
+ const uint8_t *ptr = key_buf;
+ size_t version = *ptr++;
+
+ if (version != expected_version) {
+ session->last_error = OLM_BAD_SESSION_KEY;
+ return (size_t)-1;
+ }
+
+ uint32_t counter = 0;
+ // Decode counter as a big endian 32-bit number.
+ for (unsigned i = 0; i < 4; i++) {
+ counter <<= 8; counter |= *ptr++;
+ }
+
+ megolm_init(&session->initial_ratchet, ptr, counter);
+ megolm_init(&session->latest_ratchet, ptr, counter);
+
+ ptr += MEGOLM_RATCHET_LENGTH;
+ memcpy(
+ session->signing_key.public_key, ptr, ED25519_PUBLIC_KEY_LENGTH
+ );
+ ptr += ED25519_PUBLIC_KEY_LENGTH;
+
+ if (!export_format) {
+ if (!_olm_crypto_ed25519_verify(&session->signing_key, key_buf,
+ ptr - key_buf, ptr)) {
+ session->last_error = OLM_BAD_SIGNATURE;
+ return (size_t)-1;
+ }
+
+ /* signed keyshare */
+ session->signing_key_verified = 1;
+ }
+ return 0;
+}
+
+size_t olm_init_inbound_group_session(
+ OlmInboundGroupSession *session,
+ const uint8_t * session_key, size_t session_key_length
+) {
+ uint8_t key_buf[SESSION_KEY_RAW_LENGTH];
+ size_t raw_length = _olm_decode_base64_length(session_key_length);
+ size_t result;
+
+ if (raw_length == (size_t)-1) {
+ session->last_error = OLM_INVALID_BASE64;
+ return (size_t)-1;
+ }
+
+ if (raw_length != SESSION_KEY_RAW_LENGTH) {
+ session->last_error = OLM_BAD_SESSION_KEY;
+ return (size_t)-1;
+ }
+
+ _olm_decode_base64(session_key, session_key_length, key_buf);
+ result = _init_group_session_keys(session, key_buf, 0);
+ _olm_unset(key_buf, SESSION_KEY_RAW_LENGTH);
+ return result;
+}
+
+size_t olm_import_inbound_group_session(
+ OlmInboundGroupSession *session,
+ const uint8_t * session_key, size_t session_key_length
+) {
+ uint8_t key_buf[SESSION_EXPORT_RAW_LENGTH];
+ size_t raw_length = _olm_decode_base64_length(session_key_length);
+ size_t result;
+
+ if (raw_length == (size_t)-1) {
+ session->last_error = OLM_INVALID_BASE64;
+ return (size_t)-1;
+ }
+
+ if (raw_length != SESSION_EXPORT_RAW_LENGTH) {
+ session->last_error = OLM_BAD_SESSION_KEY;
+ return (size_t)-1;
+ }
+
+ _olm_decode_base64(session_key, session_key_length, key_buf);
+ result = _init_group_session_keys(session, key_buf, 1);
+ _olm_unset(key_buf, SESSION_EXPORT_RAW_LENGTH);
+ return result;
+}
+
+static size_t raw_pickle_length(
+ const OlmInboundGroupSession *session
+) {
+ size_t length = 0;
+ length += _olm_pickle_uint32_length(PICKLE_VERSION);
+ length += megolm_pickle_length(&session->initial_ratchet);
+ length += megolm_pickle_length(&session->latest_ratchet);
+ length += _olm_pickle_ed25519_public_key_length(&session->signing_key);
+ length += _olm_pickle_bool_length(session->signing_key_verified);
+ return length;
+}
+
+size_t olm_pickle_inbound_group_session_length(
+ const OlmInboundGroupSession *session
+) {
+ return _olm_enc_output_length(raw_pickle_length(session));
+}
+
+size_t olm_pickle_inbound_group_session(
+ OlmInboundGroupSession *session,
+ void const * key, size_t key_length,
+ void * pickled, size_t pickled_length
+) {
+ size_t raw_length = raw_pickle_length(session);
+ uint8_t *pos;
+
+ if (pickled_length < _olm_enc_output_length(raw_length)) {
+ session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return (size_t)-1;
+ }
+
+ pos = _olm_enc_output_pos(pickled, raw_length);
+ pos = _olm_pickle_uint32(pos, PICKLE_VERSION);
+ pos = megolm_pickle(&session->initial_ratchet, pos);
+ pos = megolm_pickle(&session->latest_ratchet, pos);
+ pos = _olm_pickle_ed25519_public_key(pos, &session->signing_key);
+ pos = _olm_pickle_bool(pos, session->signing_key_verified);
+
+ return _olm_enc_output(key, key_length, pickled, raw_length);
+}
+
+size_t olm_unpickle_inbound_group_session(
+ OlmInboundGroupSession *session,
+ void const * key, size_t key_length,
+ void * pickled, size_t pickled_length
+) {
+ const uint8_t *pos;
+ const uint8_t *end;
+ uint32_t pickle_version;
+
+ size_t raw_length = _olm_enc_input(
+ key, key_length, pickled, pickled_length, &(session->last_error)
+ );
+ if (raw_length == (size_t)-1) {
+ return raw_length;
+ }
+
+ pos = pickled;
+ end = pos + raw_length;
+
+ pos = _olm_unpickle_uint32(pos, end, &pickle_version);
+ FAIL_ON_CORRUPTED_PICKLE(pos, session);
+
+ if (pickle_version < 1 || pickle_version > PICKLE_VERSION) {
+ session->last_error = OLM_UNKNOWN_PICKLE_VERSION;
+ return (size_t)-1;
+ }
+
+ pos = megolm_unpickle(&session->initial_ratchet, pos, end);
+ FAIL_ON_CORRUPTED_PICKLE(pos, session);
+
+ pos = megolm_unpickle(&session->latest_ratchet, pos, end);
+ FAIL_ON_CORRUPTED_PICKLE(pos, session);
+
+ pos = _olm_unpickle_ed25519_public_key(pos, end, &session->signing_key);
+ FAIL_ON_CORRUPTED_PICKLE(pos, session);
+
+ if (pickle_version == 1) {
+ /* pickle v1 had no signing_key_verified field (all keyshares were
+ * verified at import time) */
+ session->signing_key_verified = 1;
+ } else {
+ pos = _olm_unpickle_bool(pos, end, &(session->signing_key_verified));
+ }
+ FAIL_ON_CORRUPTED_PICKLE(pos, session);
+
+ if (pos != end) {
+ /* Input was longer than expected. */
+ session->last_error = OLM_PICKLE_EXTRA_DATA;
+ return (size_t)-1;
+ }
+
+ return pickled_length;
+}
+
+/**
+ * get the max plaintext length in an un-base64-ed message
+ */
+static size_t _decrypt_max_plaintext_length(
+ OlmInboundGroupSession *session,
+ uint8_t * message, size_t message_length
+) {
+ struct _OlmDecodeGroupMessageResults decoded_results;
+
+ _olm_decode_group_message(
+ message, message_length,
+ megolm_cipher->ops->mac_length(megolm_cipher),
+ ED25519_SIGNATURE_LENGTH,
+ &decoded_results);
+
+ if (decoded_results.version != OLM_PROTOCOL_VERSION) {
+ session->last_error = OLM_BAD_MESSAGE_VERSION;
+ return (size_t)-1;
+ }
+
+ if (!decoded_results.ciphertext) {
+ session->last_error = OLM_BAD_MESSAGE_FORMAT;
+ return (size_t)-1;
+ }
+
+ return megolm_cipher->ops->decrypt_max_plaintext_length(
+ megolm_cipher, decoded_results.ciphertext_length);
+}
+
+size_t olm_group_decrypt_max_plaintext_length(
+ OlmInboundGroupSession *session,
+ uint8_t * message, size_t message_length
+) {
+ size_t raw_length;
+
+ raw_length = _olm_decode_base64(message, message_length, message);
+ if (raw_length == (size_t)-1) {
+ session->last_error = OLM_INVALID_BASE64;
+ return (size_t)-1;
+ }
+
+ return _decrypt_max_plaintext_length(
+ session, message, raw_length
+ );
+}
+
+/**
+ * get a copy of the megolm ratchet, advanced
+ * to the relevant index. Returns 0 on success, -1 on error
+ */
+static size_t _get_megolm(
+ OlmInboundGroupSession *session, uint32_t message_index, Megolm *result
+) {
+ /* pick a megolm instance to use. If we're at or beyond the latest ratchet
+ * value, use that */
+ if ((message_index - session->latest_ratchet.counter) < (1U << 31)) {
+ megolm_advance_to(&session->latest_ratchet, message_index);
+ *result = session->latest_ratchet;
+ return 0;
+ } else if ((message_index - session->initial_ratchet.counter) >= (1U << 31)) {
+ /* the counter is before our intial ratchet - we can't decode this. */
+ session->last_error = OLM_UNKNOWN_MESSAGE_INDEX;
+ return (size_t)-1;
+ } else {
+ /* otherwise, start from the initial megolm. Take a copy so that we
+ * don't overwrite the initial megolm */
+ *result = session->initial_ratchet;
+ megolm_advance_to(result, message_index);
+ return 0;
+ }
+}
+
+/**
+ * decrypt an un-base64-ed message
+ */
+static size_t _decrypt(
+ OlmInboundGroupSession *session,
+ uint8_t * message, size_t message_length,
+ uint8_t * plaintext, size_t max_plaintext_length,
+ uint32_t * message_index
+) {
+ struct _OlmDecodeGroupMessageResults decoded_results;
+ size_t max_length, r;
+ Megolm megolm;
+
+ _olm_decode_group_message(
+ message, message_length,
+ megolm_cipher->ops->mac_length(megolm_cipher),
+ ED25519_SIGNATURE_LENGTH,
+ &decoded_results);
+
+ if (decoded_results.version != OLM_PROTOCOL_VERSION) {
+ session->last_error = OLM_BAD_MESSAGE_VERSION;
+ return (size_t)-1;
+ }
+
+ if (!decoded_results.has_message_index || !decoded_results.ciphertext) {
+ session->last_error = OLM_BAD_MESSAGE_FORMAT;
+ return (size_t)-1;
+ }
+
+ if (message_index != NULL) {
+ *message_index = decoded_results.message_index;
+ }
+
+ /* verify the signature. We could do this before decoding the message, but
+ * we allow for the possibility of future protocol versions which use a
+ * different signing mechanism; we would rather throw "BAD_MESSAGE_VERSION"
+ * than "BAD_SIGNATURE" in this case.
+ */
+ message_length -= ED25519_SIGNATURE_LENGTH;
+ r = _olm_crypto_ed25519_verify(
+ &session->signing_key,
+ message, message_length,
+ message + message_length
+ );
+ if (!r) {
+ session->last_error = OLM_BAD_SIGNATURE;
+ return (size_t)-1;
+ }
+
+ max_length = megolm_cipher->ops->decrypt_max_plaintext_length(
+ megolm_cipher,
+ decoded_results.ciphertext_length
+ );
+ if (max_plaintext_length < max_length) {
+ session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return (size_t)-1;
+ }
+
+ r = _get_megolm(session, decoded_results.message_index, &megolm);
+ if (r == (size_t)-1) {
+ return r;
+ }
+
+ /* now try checking the mac, and decrypting */
+ r = megolm_cipher->ops->decrypt(
+ megolm_cipher,
+ megolm_get_data(&megolm), MEGOLM_RATCHET_LENGTH,
+ message, message_length,
+ decoded_results.ciphertext, decoded_results.ciphertext_length,
+ plaintext, max_plaintext_length
+ );
+
+ _olm_unset(&megolm, sizeof(megolm));
+ if (r == (size_t)-1) {
+ session->last_error = OLM_BAD_MESSAGE_MAC;
+ return r;
+ }
+
+ /* once we have successfully decrypted a message, set a flag to say the
+ * session appears valid. */
+ session->signing_key_verified = 1;
+
+ return r;
+}
+
+size_t olm_group_decrypt(
+ OlmInboundGroupSession *session,
+ uint8_t * message, size_t message_length,
+ uint8_t * plaintext, size_t max_plaintext_length,
+ uint32_t * message_index
+) {
+ size_t raw_message_length;
+
+ raw_message_length = _olm_decode_base64(message, message_length, message);
+ if (raw_message_length == (size_t)-1) {
+ session->last_error = OLM_INVALID_BASE64;
+ return (size_t)-1;
+ }
+
+ return _decrypt(
+ session, message, raw_message_length,
+ plaintext, max_plaintext_length,
+ message_index
+ );
+}
+
+size_t olm_inbound_group_session_id_length(
+ const OlmInboundGroupSession *session
+) {
+ return _olm_encode_base64_length(GROUP_SESSION_ID_LENGTH);
+}
+
+size_t olm_inbound_group_session_id(
+ OlmInboundGroupSession *session,
+ uint8_t * id, size_t id_length
+) {
+ if (id_length < olm_inbound_group_session_id_length(session)) {
+ session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return (size_t)-1;
+ }
+
+ return _olm_encode_base64(
+ session->signing_key.public_key, GROUP_SESSION_ID_LENGTH, id
+ );
+}
+
+uint32_t olm_inbound_group_session_first_known_index(
+ const OlmInboundGroupSession *session
+) {
+ return session->initial_ratchet.counter;
+}
+
+int olm_inbound_group_session_is_verified(
+ const OlmInboundGroupSession *session
+) {
+ return session->signing_key_verified;
+}
+
+size_t olm_export_inbound_group_session_length(
+ const OlmInboundGroupSession *session
+) {
+ return _olm_encode_base64_length(SESSION_EXPORT_RAW_LENGTH);
+}
+
+size_t olm_export_inbound_group_session(
+ OlmInboundGroupSession *session,
+ uint8_t * key, size_t key_length, uint32_t message_index
+) {
+ uint8_t *raw;
+ uint8_t *ptr;
+ Megolm megolm;
+ size_t r;
+ size_t encoded_length = olm_export_inbound_group_session_length(session);
+
+ if (key_length < encoded_length) {
+ session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return (size_t)-1;
+ }
+
+ r = _get_megolm(session, message_index, &megolm);
+ if (r == (size_t)-1) {
+ return r;
+ }
+
+ /* put the raw data at the end of the output buffer. */
+ raw = ptr = key + encoded_length - SESSION_EXPORT_RAW_LENGTH;
+ *ptr++ = SESSION_EXPORT_VERSION;
+
+ // Encode message index as a big endian 32-bit number.
+ for (unsigned i = 0; i < 4; i++) {
+ *ptr++ = 0xFF & (message_index >> 24); message_index <<= 8;
+ }
+
+ memcpy(ptr, megolm_get_data(&megolm), MEGOLM_RATCHET_LENGTH);
+ ptr += MEGOLM_RATCHET_LENGTH;
+
+ memcpy(
+ ptr, session->signing_key.public_key,
+ ED25519_PUBLIC_KEY_LENGTH
+ );
+ ptr += ED25519_PUBLIC_KEY_LENGTH;
+
+ return _olm_encode_base64(raw, SESSION_EXPORT_RAW_LENGTH, key);
+}
diff --git a/ext/olm/src/megolm.c b/ext/olm/src/megolm.c
new file mode 100644
index 0000000..c4d1110
--- /dev/null
+++ b/ext/olm/src/megolm.c
@@ -0,0 +1,154 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+#include "olm/megolm.h"
+
+#include <string.h>
+
+#include "olm/cipher.h"
+#include "olm/crypto.h"
+#include "olm/pickle.h"
+
+static const struct _olm_cipher_aes_sha_256 MEGOLM_CIPHER =
+ OLM_CIPHER_INIT_AES_SHA_256("MEGOLM_KEYS");
+const struct _olm_cipher *megolm_cipher = OLM_CIPHER_BASE(&MEGOLM_CIPHER);
+
+/* the seeds used in the HMAC-SHA-256 functions for each part of the ratchet.
+ */
+#define HASH_KEY_SEED_LENGTH 1
+static uint8_t HASH_KEY_SEEDS[MEGOLM_RATCHET_PARTS][HASH_KEY_SEED_LENGTH] = {
+ {0x00},
+ {0x01},
+ {0x02},
+ {0x03}
+};
+
+static void rehash_part(
+ uint8_t data[MEGOLM_RATCHET_PARTS][MEGOLM_RATCHET_PART_LENGTH],
+ int rehash_from_part, int rehash_to_part
+) {
+ _olm_crypto_hmac_sha256(
+ data[rehash_from_part],
+ MEGOLM_RATCHET_PART_LENGTH,
+ HASH_KEY_SEEDS[rehash_to_part], HASH_KEY_SEED_LENGTH,
+ data[rehash_to_part]
+ );
+}
+
+
+
+void megolm_init(Megolm *megolm, uint8_t const *random_data, uint32_t counter) {
+ megolm->counter = counter;
+ memcpy(megolm->data, random_data, MEGOLM_RATCHET_LENGTH);
+}
+
+size_t megolm_pickle_length(const Megolm *megolm) {
+ size_t length = 0;
+ length += _olm_pickle_bytes_length(megolm_get_data(megolm), MEGOLM_RATCHET_LENGTH);
+ length += _olm_pickle_uint32_length(megolm->counter);
+ return length;
+
+}
+
+uint8_t * megolm_pickle(const Megolm *megolm, uint8_t *pos) {
+ pos = _olm_pickle_bytes(pos, megolm_get_data(megolm), MEGOLM_RATCHET_LENGTH);
+ pos = _olm_pickle_uint32(pos, megolm->counter);
+ return pos;
+}
+
+const uint8_t * megolm_unpickle(Megolm *megolm, const uint8_t *pos,
+ const uint8_t *end) {
+ pos = _olm_unpickle_bytes(pos, end, (uint8_t *)(megolm->data),
+ MEGOLM_RATCHET_LENGTH);
+ UNPICKLE_OK(pos);
+
+ pos = _olm_unpickle_uint32(pos, end, &megolm->counter);
+ UNPICKLE_OK(pos);
+
+ return pos;
+}
+
+/* simplistic implementation for a single step */
+void megolm_advance(Megolm *megolm) {
+ uint32_t mask = 0x00FFFFFF;
+ int h = 0;
+ int i;
+
+ megolm->counter++;
+
+ /* figure out how much we need to rekey */
+ while (h < (int)MEGOLM_RATCHET_PARTS) {
+ if (!(megolm->counter & mask))
+ break;
+ h++;
+ mask >>= 8;
+ }
+
+ /* now update R(h)...R(3) based on R(h) */
+ for (i = MEGOLM_RATCHET_PARTS-1; i >= h; i--) {
+ rehash_part(megolm->data, h, i);
+ }
+}
+
+void megolm_advance_to(Megolm *megolm, uint32_t advance_to) {
+ int j;
+
+ /* starting with R0, see if we need to update each part of the hash */
+ for (j = 0; j < (int)MEGOLM_RATCHET_PARTS; j++) {
+ int shift = (MEGOLM_RATCHET_PARTS-j-1) * 8;
+ uint32_t mask = (~(uint32_t)0) << shift;
+ int k;
+
+ /* how many times do we need to rehash this part?
+ *
+ * '& 0xff' ensures we handle integer wraparound correctly
+ */
+ unsigned int steps =
+ ((advance_to >> shift) - (megolm->counter >> shift)) & 0xff;
+
+ if (steps == 0) {
+ /* deal with the edge case where megolm->counter is slightly larger
+ * than advance_to. This should only happen for R(0), and implies
+ * that advance_to has wrapped around and we need to advance R(0)
+ * 256 times.
+ */
+ if (advance_to < megolm->counter) {
+ steps = 0x100;
+ } else {
+ continue;
+ }
+ }
+
+ /* for all but the last step, we can just bump R(j) without regard
+ * to R(j+1)...R(3).
+ */
+ while (steps > 1) {
+ rehash_part(megolm->data, j, j);
+ steps --;
+ }
+
+ /* on the last step we also need to bump R(j+1)...R(3).
+ *
+ * (Theoretically, we could skip bumping R(j+2) if we're going to bump
+ * R(j+1) again, but the code to figure that out is a bit baroque and
+ * doesn't save us much).
+ */
+ for (k = 3; k >= j; k--) {
+ rehash_part(megolm->data, j, k);
+ }
+ megolm->counter = advance_to & mask;
+ }
+}
diff --git a/ext/olm/src/memory.cpp b/ext/olm/src/memory.cpp
new file mode 100644
index 0000000..20e0683
--- /dev/null
+++ b/ext/olm/src/memory.cpp
@@ -0,0 +1,45 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "olm/memory.hh"
+#include "olm/memory.h"
+
+void _olm_unset(
+ void volatile * buffer, size_t buffer_length
+) {
+ olm::unset(buffer, buffer_length);
+}
+
+void olm::unset(
+ void volatile * buffer, std::size_t buffer_length
+) {
+ char volatile * pos = reinterpret_cast<char volatile *>(buffer);
+ char volatile * end = pos + buffer_length;
+ while (pos != end) {
+ *(pos++) = 0;
+ }
+}
+
+
+bool olm::is_equal(
+ std::uint8_t const * buffer_a,
+ std::uint8_t const * buffer_b,
+ std::size_t length
+) {
+ std::uint8_t volatile result = 0;
+ while (length--) {
+ result |= (*(buffer_a++)) ^ (*(buffer_b++));
+ }
+ return result == 0;
+}
diff --git a/ext/olm/src/message.cpp b/ext/olm/src/message.cpp
new file mode 100644
index 0000000..e5e63f0
--- /dev/null
+++ b/ext/olm/src/message.cpp
@@ -0,0 +1,406 @@
+/* Copyright 2015-2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "olm/message.hh"
+
+#include "olm/memory.hh"
+
+namespace {
+
+template<typename T>
+static std::size_t varint_length(
+ T value
+) {
+ std::size_t result = 1;
+ while (value >= 128U) {
+ ++result;
+ value >>= 7;
+ }
+ return result;
+}
+
+
+template<typename T>
+static std::uint8_t * varint_encode(
+ std::uint8_t * output,
+ T value
+) {
+ while (value >= 128U) {
+ *(output++) = (0x7F & value) | 0x80;
+ value >>= 7;
+ }
+ (*output++) = value;
+ return output;
+}
+
+
+template<typename T>
+static T varint_decode(
+ std::uint8_t const * varint_start,
+ std::uint8_t const * varint_end
+) {
+ T value = 0;
+ if (varint_end == varint_start) {
+ return 0;
+ }
+ do {
+ value <<= 7;
+ value |= 0x7F & *(--varint_end);
+ } while (varint_end != varint_start);
+ return value;
+}
+
+
+static std::uint8_t const * varint_skip(
+ std::uint8_t const * input,
+ std::uint8_t const * input_end
+) {
+ while (input != input_end) {
+ std::uint8_t tmp = *(input++);
+ if ((tmp & 0x80) == 0) {
+ return input;
+ }
+ }
+ return input;
+}
+
+
+static std::size_t varstring_length(
+ std::size_t string_length
+) {
+ return varint_length(string_length) + string_length;
+}
+
+static std::size_t const VERSION_LENGTH = 1;
+static std::uint8_t const RATCHET_KEY_TAG = 012;
+static std::uint8_t const COUNTER_TAG = 020;
+static std::uint8_t const CIPHERTEXT_TAG = 042;
+
+static std::uint8_t * encode(
+ std::uint8_t * pos,
+ std::uint8_t tag,
+ std::uint32_t value
+) {
+ *(pos++) = tag;
+ return varint_encode(pos, value);
+}
+
+static std::uint8_t * encode(
+ std::uint8_t * pos,
+ std::uint8_t tag,
+ std::uint8_t * & value, std::size_t value_length
+) {
+ *(pos++) = tag;
+ pos = varint_encode(pos, value_length);
+ value = pos;
+ return pos + value_length;
+}
+
+static std::uint8_t const * decode(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ std::uint8_t tag,
+ std::uint32_t & value, bool & has_value
+) {
+ if (pos != end && *pos == tag) {
+ ++pos;
+ std::uint8_t const * value_start = pos;
+ pos = varint_skip(pos, end);
+ value = varint_decode<std::uint32_t>(value_start, pos);
+ has_value = true;
+ }
+ return pos;
+}
+
+
+static std::uint8_t const * decode(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ std::uint8_t tag,
+ std::uint8_t const * & value, std::size_t & value_length
+) {
+ if (pos != end && *pos == tag) {
+ ++pos;
+ std::uint8_t const * len_start = pos;
+ pos = varint_skip(pos, end);
+ std::size_t len = varint_decode<std::size_t>(len_start, pos);
+ if (len > std::size_t(end - pos)) return end;
+ value = pos;
+ value_length = len;
+ pos += len;
+ }
+ return pos;
+}
+
+static std::uint8_t const * skip_unknown(
+ std::uint8_t const * pos, std::uint8_t const * end
+) {
+ if (pos != end) {
+ uint8_t tag = *pos;
+ if ((tag & 0x7) == 0) {
+ pos = varint_skip(pos, end);
+ pos = varint_skip(pos, end);
+ } else if ((tag & 0x7) == 2) {
+ pos = varint_skip(pos, end);
+ std::uint8_t const * len_start = pos;
+ pos = varint_skip(pos, end);
+ std::size_t len = varint_decode<std::size_t>(len_start, pos);
+ if (len > std::size_t(end - pos)) return end;
+ pos += len;
+ } else {
+ return end;
+ }
+ }
+ return pos;
+}
+
+} // namespace
+
+
+std::size_t olm::encode_message_length(
+ std::uint32_t counter,
+ std::size_t ratchet_key_length,
+ std::size_t ciphertext_length,
+ std::size_t mac_length
+) {
+ std::size_t length = VERSION_LENGTH;
+ length += 1 + varstring_length(ratchet_key_length);
+ length += 1 + varint_length(counter);
+ length += 1 + varstring_length(ciphertext_length);
+ length += mac_length;
+ return length;
+}
+
+
+void olm::encode_message(
+ olm::MessageWriter & writer,
+ std::uint8_t version,
+ std::uint32_t counter,
+ std::size_t ratchet_key_length,
+ std::size_t ciphertext_length,
+ std::uint8_t * output
+) {
+ std::uint8_t * pos = output;
+ *(pos++) = version;
+ pos = encode(pos, RATCHET_KEY_TAG, writer.ratchet_key, ratchet_key_length);
+ pos = encode(pos, COUNTER_TAG, counter);
+ pos = encode(pos, CIPHERTEXT_TAG, writer.ciphertext, ciphertext_length);
+}
+
+
+void olm::decode_message(
+ olm::MessageReader & reader,
+ std::uint8_t const * input, std::size_t input_length,
+ std::size_t mac_length
+) {
+ std::uint8_t const * pos = input;
+ std::uint8_t const * end = input + input_length - mac_length;
+ std::uint8_t const * unknown = nullptr;
+
+ reader.version = 0;
+ reader.has_counter = false;
+ reader.counter = 0;
+ reader.input = input;
+ reader.input_length = input_length;
+ reader.ratchet_key = nullptr;
+ reader.ratchet_key_length = 0;
+ reader.ciphertext = nullptr;
+ reader.ciphertext_length = 0;
+
+ if (input_length < mac_length) return;
+
+ if (pos == end) return;
+ reader.version = *(pos++);
+
+ while (pos != end) {
+ unknown = pos;
+ pos = decode(
+ pos, end, RATCHET_KEY_TAG,
+ reader.ratchet_key, reader.ratchet_key_length
+ );
+ pos = decode(
+ pos, end, COUNTER_TAG,
+ reader.counter, reader.has_counter
+ );
+ pos = decode(
+ pos, end, CIPHERTEXT_TAG,
+ reader.ciphertext, reader.ciphertext_length
+ );
+ if (unknown == pos) {
+ pos = skip_unknown(pos, end);
+ }
+ }
+}
+
+
+namespace {
+
+static std::uint8_t const ONE_TIME_KEY_ID_TAG = 012;
+static std::uint8_t const BASE_KEY_TAG = 022;
+static std::uint8_t const IDENTITY_KEY_TAG = 032;
+static std::uint8_t const MESSAGE_TAG = 042;
+
+} // namespace
+
+
+std::size_t olm::encode_one_time_key_message_length(
+ std::size_t one_time_key_length,
+ std::size_t identity_key_length,
+ std::size_t base_key_length,
+ std::size_t message_length
+) {
+ std::size_t length = VERSION_LENGTH;
+ length += 1 + varstring_length(one_time_key_length);
+ length += 1 + varstring_length(identity_key_length);
+ length += 1 + varstring_length(base_key_length);
+ length += 1 + varstring_length(message_length);
+ return length;
+}
+
+
+void olm::encode_one_time_key_message(
+ olm::PreKeyMessageWriter & writer,
+ std::uint8_t version,
+ std::size_t identity_key_length,
+ std::size_t base_key_length,
+ std::size_t one_time_key_length,
+ std::size_t message_length,
+ std::uint8_t * output
+) {
+ std::uint8_t * pos = output;
+ *(pos++) = version;
+ pos = encode(pos, ONE_TIME_KEY_ID_TAG, writer.one_time_key, one_time_key_length);
+ pos = encode(pos, BASE_KEY_TAG, writer.base_key, base_key_length);
+ pos = encode(pos, IDENTITY_KEY_TAG, writer.identity_key, identity_key_length);
+ pos = encode(pos, MESSAGE_TAG, writer.message, message_length);
+}
+
+
+void olm::decode_one_time_key_message(
+ PreKeyMessageReader & reader,
+ std::uint8_t const * input, std::size_t input_length
+) {
+ std::uint8_t const * pos = input;
+ std::uint8_t const * end = input + input_length;
+ std::uint8_t const * unknown = nullptr;
+
+ reader.version = 0;
+ reader.one_time_key = nullptr;
+ reader.one_time_key_length = 0;
+ reader.identity_key = nullptr;
+ reader.identity_key_length = 0;
+ reader.base_key = nullptr;
+ reader.base_key_length = 0;
+ reader.message = nullptr;
+ reader.message_length = 0;
+
+ if (pos == end) return;
+ reader.version = *(pos++);
+
+ while (pos != end) {
+ unknown = pos;
+ pos = decode(
+ pos, end, ONE_TIME_KEY_ID_TAG,
+ reader.one_time_key, reader.one_time_key_length
+ );
+ pos = decode(
+ pos, end, BASE_KEY_TAG,
+ reader.base_key, reader.base_key_length
+ );
+ pos = decode(
+ pos, end, IDENTITY_KEY_TAG,
+ reader.identity_key, reader.identity_key_length
+ );
+ pos = decode(
+ pos, end, MESSAGE_TAG,
+ reader.message, reader.message_length
+ );
+ if (unknown == pos) {
+ pos = skip_unknown(pos, end);
+ }
+ }
+}
+
+
+
+static const std::uint8_t GROUP_MESSAGE_INDEX_TAG = 010;
+static const std::uint8_t GROUP_CIPHERTEXT_TAG = 022;
+
+size_t _olm_encode_group_message_length(
+ uint32_t message_index,
+ size_t ciphertext_length,
+ size_t mac_length,
+ size_t signature_length
+) {
+ size_t length = VERSION_LENGTH;
+ length += 1 + varint_length(message_index);
+ length += 1 + varstring_length(ciphertext_length);
+ length += mac_length;
+ length += signature_length;
+ return length;
+}
+
+
+size_t _olm_encode_group_message(
+ uint8_t version,
+ uint32_t message_index,
+ size_t ciphertext_length,
+ uint8_t *output,
+ uint8_t **ciphertext_ptr
+) {
+ std::uint8_t * pos = output;
+
+ *(pos++) = version;
+ pos = encode(pos, GROUP_MESSAGE_INDEX_TAG, message_index);
+ pos = encode(pos, GROUP_CIPHERTEXT_TAG, *ciphertext_ptr, ciphertext_length);
+ return pos-output;
+}
+
+void _olm_decode_group_message(
+ const uint8_t *input, size_t input_length,
+ size_t mac_length, size_t signature_length,
+ struct _OlmDecodeGroupMessageResults *results
+) {
+ std::uint8_t const * pos = input;
+ std::size_t trailer_length = mac_length + signature_length;
+ std::uint8_t const * end = input + input_length - trailer_length;
+ std::uint8_t const * unknown = nullptr;
+
+ bool has_message_index = false;
+ results->version = 0;
+ results->message_index = 0;
+ results->has_message_index = (int)has_message_index;
+ results->ciphertext = nullptr;
+ results->ciphertext_length = 0;
+
+ if (input_length < trailer_length) return;
+
+ if (pos == end) return;
+ results->version = *(pos++);
+
+ while (pos != end) {
+ unknown = pos;
+ pos = decode(
+ pos, end, GROUP_MESSAGE_INDEX_TAG,
+ results->message_index, has_message_index
+ );
+ pos = decode(
+ pos, end, GROUP_CIPHERTEXT_TAG,
+ results->ciphertext, results->ciphertext_length
+ );
+ if (unknown == pos) {
+ pos = skip_unknown(pos, end);
+ }
+ }
+
+ results->has_message_index = (int)has_message_index;
+}
diff --git a/ext/olm/src/olm.cpp b/ext/olm/src/olm.cpp
new file mode 100644
index 0000000..3a30f7a
--- /dev/null
+++ b/ext/olm/src/olm.cpp
@@ -0,0 +1,846 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "olm/olm.h"
+#include "olm/session.hh"
+#include "olm/account.hh"
+#include "olm/cipher.h"
+#include "olm/pickle_encoding.h"
+#include "olm/utility.hh"
+#include "olm/base64.hh"
+#include "olm/memory.hh"
+
+#include <new>
+#include <cstring>
+
+namespace {
+
+static OlmAccount * to_c(olm::Account * account) {
+ return reinterpret_cast<OlmAccount *>(account);
+}
+
+static OlmSession * to_c(olm::Session * session) {
+ return reinterpret_cast<OlmSession *>(session);
+}
+
+static OlmUtility * to_c(olm::Utility * utility) {
+ return reinterpret_cast<OlmUtility *>(utility);
+}
+
+static olm::Account * from_c(OlmAccount * account) {
+ return reinterpret_cast<olm::Account *>(account);
+}
+
+static const olm::Account * from_c(OlmAccount const * account) {
+ return reinterpret_cast<olm::Account const *>(account);
+}
+
+static olm::Session * from_c(OlmSession * session) {
+ return reinterpret_cast<olm::Session *>(session);
+}
+
+static const olm::Session * from_c(OlmSession const * session) {
+ return reinterpret_cast<const olm::Session *>(session);
+}
+
+static olm::Utility * from_c(OlmUtility * utility) {
+ return reinterpret_cast<olm::Utility *>(utility);
+}
+
+static const olm::Utility * from_c(OlmUtility const * utility) {
+ return reinterpret_cast<const olm::Utility *>(utility);
+}
+
+static std::uint8_t * from_c(void * bytes) {
+ return reinterpret_cast<std::uint8_t *>(bytes);
+}
+
+static std::uint8_t const * from_c(void const * bytes) {
+ return reinterpret_cast<std::uint8_t const *>(bytes);
+}
+
+std::size_t b64_output_length(
+ size_t raw_length
+) {
+ return olm::encode_base64_length(raw_length);
+}
+
+std::uint8_t * b64_output_pos(
+ std::uint8_t * output,
+ size_t raw_length
+) {
+ return output + olm::encode_base64_length(raw_length) - raw_length;
+}
+
+std::size_t b64_output(
+ std::uint8_t * output, size_t raw_length
+) {
+ std::size_t base64_length = olm::encode_base64_length(raw_length);
+ std::uint8_t * raw_output = output + base64_length - raw_length;
+ olm::encode_base64(raw_output, raw_length, output);
+ return base64_length;
+}
+
+std::size_t b64_input(
+ std::uint8_t * input, size_t b64_length,
+ OlmErrorCode & last_error
+) {
+ std::size_t raw_length = olm::decode_base64_length(b64_length);
+ if (raw_length == std::size_t(-1)) {
+ last_error = OlmErrorCode::OLM_INVALID_BASE64;
+ return std::size_t(-1);
+ }
+ olm::decode_base64(input, b64_length, input);
+ return raw_length;
+}
+
+} // namespace
+
+
+extern "C" {
+
+void olm_get_library_version(uint8_t *major, uint8_t *minor, uint8_t *patch) {
+ if (major != NULL) *major = OLMLIB_VERSION_MAJOR;
+ if (minor != NULL) *minor = OLMLIB_VERSION_MINOR;
+ if (patch != NULL) *patch = OLMLIB_VERSION_PATCH;
+}
+
+size_t olm_error(void) {
+ return std::size_t(-1);
+}
+
+
+const char * olm_account_last_error(
+ const OlmAccount * account
+) {
+ auto error = from_c(account)->last_error;
+ return _olm_error_to_string(error);
+}
+
+enum OlmErrorCode olm_account_last_error_code(
+ const OlmAccount * account
+) {
+ return from_c(account)->last_error;
+}
+
+const char * olm_session_last_error(
+ const OlmSession * session
+) {
+ auto error = from_c(session)->last_error;
+ return _olm_error_to_string(error);
+}
+
+enum OlmErrorCode olm_session_last_error_code(
+ OlmSession const * session
+) {
+ return from_c(session)->last_error;
+}
+
+const char * olm_utility_last_error(
+ OlmUtility const * utility
+) {
+ auto error = from_c(utility)->last_error;
+ return _olm_error_to_string(error);
+}
+
+enum OlmErrorCode olm_utility_last_error_code(
+ OlmUtility const * utility
+) {
+ return from_c(utility)->last_error;
+}
+
+size_t olm_account_size(void) {
+ return sizeof(olm::Account);
+}
+
+
+size_t olm_session_size(void) {
+ return sizeof(olm::Session);
+}
+
+size_t olm_utility_size(void) {
+ return sizeof(olm::Utility);
+}
+
+OlmAccount * olm_account(
+ void * memory
+) {
+ olm::unset(memory, sizeof(olm::Account));
+ return to_c(new(memory) olm::Account());
+}
+
+
+OlmSession * olm_session(
+ void * memory
+) {
+ olm::unset(memory, sizeof(olm::Session));
+ return to_c(new(memory) olm::Session());
+}
+
+
+OlmUtility * olm_utility(
+ void * memory
+) {
+ olm::unset(memory, sizeof(olm::Utility));
+ return to_c(new(memory) olm::Utility());
+}
+
+
+size_t olm_clear_account(
+ OlmAccount * account
+) {
+ /* Clear the memory backing the account */
+ olm::unset(account, sizeof(olm::Account));
+ /* Initialise a fresh account object in case someone tries to use it */
+ new(account) olm::Account();
+ return sizeof(olm::Account);
+}
+
+
+size_t olm_clear_session(
+ OlmSession * session
+) {
+ /* Clear the memory backing the session */
+ olm::unset(session, sizeof(olm::Session));
+ /* Initialise a fresh session object in case someone tries to use it */
+ new(session) olm::Session();
+ return sizeof(olm::Session);
+}
+
+
+size_t olm_clear_utility(
+ OlmUtility * utility
+) {
+ /* Clear the memory backing the session */
+ olm::unset(utility, sizeof(olm::Utility));
+ /* Initialise a fresh session object in case someone tries to use it */
+ new(utility) olm::Utility();
+ return sizeof(olm::Utility);
+}
+
+
+size_t olm_pickle_account_length(
+ OlmAccount const * account
+) {
+ return _olm_enc_output_length(pickle_length(*from_c(account)));
+}
+
+
+size_t olm_pickle_session_length(
+ OlmSession const * session
+) {
+ return _olm_enc_output_length(pickle_length(*from_c(session)));
+}
+
+
+size_t olm_pickle_account(
+ OlmAccount * account,
+ void const * key, size_t key_length,
+ void * pickled, size_t pickled_length
+) {
+ olm::Account & object = *from_c(account);
+ std::size_t raw_length = pickle_length(object);
+ if (pickled_length < _olm_enc_output_length(raw_length)) {
+ object.last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return size_t(-1);
+ }
+ pickle(_olm_enc_output_pos(from_c(pickled), raw_length), object);
+ return _olm_enc_output(from_c(key), key_length, from_c(pickled), raw_length);
+}
+
+
+size_t olm_pickle_session(
+ OlmSession * session,
+ void const * key, size_t key_length,
+ void * pickled, size_t pickled_length
+) {
+ olm::Session & object = *from_c(session);
+ std::size_t raw_length = pickle_length(object);
+ if (pickled_length < _olm_enc_output_length(raw_length)) {
+ object.last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return size_t(-1);
+ }
+ pickle(_olm_enc_output_pos(from_c(pickled), raw_length), object);
+ return _olm_enc_output(from_c(key), key_length, from_c(pickled), raw_length);
+}
+
+
+size_t olm_unpickle_account(
+ OlmAccount * account,
+ void const * key, size_t key_length,
+ void * pickled, size_t pickled_length
+) {
+ olm::Account & object = *from_c(account);
+ std::uint8_t * input = from_c(pickled);
+ std::size_t raw_length = _olm_enc_input(
+ from_c(key), key_length, input, pickled_length, &object.last_error
+ );
+ if (raw_length == std::size_t(-1)) {
+ return std::size_t(-1);
+ }
+
+ std::uint8_t const * pos = input;
+ std::uint8_t const * end = pos + raw_length;
+
+ pos = unpickle(pos, end, object);
+
+ if (!pos) {
+ /* Input was corrupted. */
+ if (object.last_error == OlmErrorCode::OLM_SUCCESS) {
+ object.last_error = OlmErrorCode::OLM_CORRUPTED_PICKLE;
+ }
+ return std::size_t(-1);
+ } else if (pos != end) {
+ /* Input was longer than expected. */
+ object.last_error = OlmErrorCode::OLM_PICKLE_EXTRA_DATA;
+ return std::size_t(-1);
+ }
+
+ return pickled_length;
+}
+
+
+size_t olm_unpickle_session(
+ OlmSession * session,
+ void const * key, size_t key_length,
+ void * pickled, size_t pickled_length
+) {
+ olm::Session & object = *from_c(session);
+ std::uint8_t * input = from_c(pickled);
+ std::size_t raw_length = _olm_enc_input(
+ from_c(key), key_length, input, pickled_length, &object.last_error
+ );
+ if (raw_length == std::size_t(-1)) {
+ return std::size_t(-1);
+ }
+
+ std::uint8_t const * pos = input;
+ std::uint8_t const * end = pos + raw_length;
+
+ pos = unpickle(pos, end, object);
+
+ if (!pos) {
+ /* Input was corrupted. */
+ if (object.last_error == OlmErrorCode::OLM_SUCCESS) {
+ object.last_error = OlmErrorCode::OLM_CORRUPTED_PICKLE;
+ }
+ return std::size_t(-1);
+ } else if (pos != end) {
+ /* Input was longer than expected. */
+ object.last_error = OlmErrorCode::OLM_PICKLE_EXTRA_DATA;
+ return std::size_t(-1);
+ }
+
+ return pickled_length;
+}
+
+
+size_t olm_create_account_random_length(
+ OlmAccount const * account
+) {
+ return from_c(account)->new_account_random_length();
+}
+
+
+size_t olm_create_account(
+ OlmAccount * account,
+ void * random, size_t random_length
+) {
+ size_t result = from_c(account)->new_account(from_c(random), random_length);
+ olm::unset(random, random_length);
+ return result;
+}
+
+
+size_t olm_account_identity_keys_length(
+ OlmAccount const * account
+) {
+ return from_c(account)->get_identity_json_length();
+}
+
+
+size_t olm_account_identity_keys(
+ OlmAccount * account,
+ void * identity_keys, size_t identity_key_length
+) {
+ return from_c(account)->get_identity_json(
+ from_c(identity_keys), identity_key_length
+ );
+}
+
+
+size_t olm_account_signature_length(
+ OlmAccount const * account
+) {
+ return b64_output_length(from_c(account)->signature_length());
+}
+
+
+size_t olm_account_sign(
+ OlmAccount * account,
+ void const * message, size_t message_length,
+ void * signature, size_t signature_length
+) {
+ std::size_t raw_length = from_c(account)->signature_length();
+ if (signature_length < b64_output_length(raw_length)) {
+ from_c(account)->last_error =
+ OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+ from_c(account)->sign(
+ from_c(message), message_length,
+ b64_output_pos(from_c(signature), raw_length), raw_length
+ );
+ return b64_output(from_c(signature), raw_length);
+}
+
+
+size_t olm_account_one_time_keys_length(
+ OlmAccount const * account
+) {
+ return from_c(account)->get_one_time_keys_json_length();
+}
+
+
+size_t olm_account_one_time_keys(
+ OlmAccount * account,
+ void * one_time_keys_json, size_t one_time_key_json_length
+) {
+ return from_c(account)->get_one_time_keys_json(
+ from_c(one_time_keys_json), one_time_key_json_length
+ );
+}
+
+
+size_t olm_account_mark_keys_as_published(
+ OlmAccount * account
+) {
+ return from_c(account)->mark_keys_as_published();
+}
+
+
+size_t olm_account_max_number_of_one_time_keys(
+ OlmAccount const * account
+) {
+ return from_c(account)->max_number_of_one_time_keys();
+}
+
+
+size_t olm_account_generate_one_time_keys_random_length(
+ OlmAccount const * account,
+ size_t number_of_keys
+) {
+ return from_c(account)->generate_one_time_keys_random_length(number_of_keys);
+}
+
+
+size_t olm_account_generate_one_time_keys(
+ OlmAccount * account,
+ size_t number_of_keys,
+ void * random, size_t random_length
+) {
+ size_t result = from_c(account)->generate_one_time_keys(
+ number_of_keys,
+ from_c(random), random_length
+ );
+ olm::unset(random, random_length);
+ return result;
+}
+
+
+size_t olm_account_generate_fallback_key_random_length(
+ OlmAccount const * account
+) {
+ return from_c(account)->generate_fallback_key_random_length();
+}
+
+
+size_t olm_account_generate_fallback_key(
+ OlmAccount * account,
+ void * random, size_t random_length
+) {
+ size_t result = from_c(account)->generate_fallback_key(
+ from_c(random), random_length
+ );
+ olm::unset(random, random_length);
+ return result;
+}
+
+
+size_t olm_account_fallback_key_length(
+ OlmAccount const * account
+) {
+ return from_c(account)->get_fallback_key_json_length();
+}
+
+
+size_t olm_account_fallback_key(
+ OlmAccount * account,
+ void * fallback_key_json, size_t fallback_key_json_length
+) {
+ return from_c(account)->get_fallback_key_json(
+ from_c(fallback_key_json), fallback_key_json_length
+ );
+}
+
+
+size_t olm_account_unpublished_fallback_key_length(
+ OlmAccount const * account
+) {
+ return from_c(account)->get_unpublished_fallback_key_json_length();
+}
+
+
+size_t olm_account_unpublished_fallback_key(
+ OlmAccount * account,
+ void * fallback_key_json, size_t fallback_key_json_length
+) {
+ return from_c(account)->get_unpublished_fallback_key_json(
+ from_c(fallback_key_json), fallback_key_json_length
+ );
+}
+
+
+void olm_account_forget_old_fallback_key(
+ OlmAccount * account
+) {
+ return from_c(account)->forget_old_fallback_key();
+}
+
+
+size_t olm_create_outbound_session_random_length(
+ OlmSession const * session
+) {
+ return from_c(session)->new_outbound_session_random_length();
+}
+
+
+size_t olm_create_outbound_session(
+ OlmSession * session,
+ OlmAccount const * account,
+ void const * their_identity_key, size_t their_identity_key_length,
+ void const * their_one_time_key, size_t their_one_time_key_length,
+ void * random, size_t random_length
+) {
+ std::uint8_t const * id_key = from_c(their_identity_key);
+ std::uint8_t const * ot_key = from_c(their_one_time_key);
+ std::size_t id_key_length = their_identity_key_length;
+ std::size_t ot_key_length = their_one_time_key_length;
+
+ if (olm::decode_base64_length(id_key_length) != CURVE25519_KEY_LENGTH
+ || olm::decode_base64_length(ot_key_length) != CURVE25519_KEY_LENGTH
+ ) {
+ from_c(session)->last_error = OlmErrorCode::OLM_INVALID_BASE64;
+ return std::size_t(-1);
+ }
+ _olm_curve25519_public_key identity_key;
+ _olm_curve25519_public_key one_time_key;
+
+ olm::decode_base64(id_key, id_key_length, identity_key.public_key);
+ olm::decode_base64(ot_key, ot_key_length, one_time_key.public_key);
+
+ size_t result = from_c(session)->new_outbound_session(
+ *from_c(account), identity_key, one_time_key,
+ from_c(random), random_length
+ );
+ olm::unset(random, random_length);
+ return result;
+}
+
+
+size_t olm_create_inbound_session(
+ OlmSession * session,
+ OlmAccount * account,
+ void * one_time_key_message, size_t message_length
+) {
+ std::size_t raw_length = b64_input(
+ from_c(one_time_key_message), message_length, from_c(session)->last_error
+ );
+ if (raw_length == std::size_t(-1)) {
+ return std::size_t(-1);
+ }
+ return from_c(session)->new_inbound_session(
+ *from_c(account), nullptr, from_c(one_time_key_message), raw_length
+ );
+}
+
+
+size_t olm_create_inbound_session_from(
+ OlmSession * session,
+ OlmAccount * account,
+ void const * their_identity_key, size_t their_identity_key_length,
+ void * one_time_key_message, size_t message_length
+) {
+ std::uint8_t const * id_key = from_c(their_identity_key);
+ std::size_t id_key_length = their_identity_key_length;
+
+ if (olm::decode_base64_length(id_key_length) != CURVE25519_KEY_LENGTH) {
+ from_c(session)->last_error = OlmErrorCode::OLM_INVALID_BASE64;
+ return std::size_t(-1);
+ }
+ _olm_curve25519_public_key identity_key;
+ olm::decode_base64(id_key, id_key_length, identity_key.public_key);
+
+ std::size_t raw_length = b64_input(
+ from_c(one_time_key_message), message_length, from_c(session)->last_error
+ );
+ if (raw_length == std::size_t(-1)) {
+ return std::size_t(-1);
+ }
+ return from_c(session)->new_inbound_session(
+ *from_c(account), &identity_key,
+ from_c(one_time_key_message), raw_length
+ );
+}
+
+
+size_t olm_session_id_length(
+ OlmSession const * session
+) {
+ return b64_output_length(from_c(session)->session_id_length());
+}
+
+size_t olm_session_id(
+ OlmSession * session,
+ void * id, size_t id_length
+) {
+ std::size_t raw_length = from_c(session)->session_id_length();
+ if (id_length < b64_output_length(raw_length)) {
+ from_c(session)->last_error =
+ OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+ std::size_t result = from_c(session)->session_id(
+ b64_output_pos(from_c(id), raw_length), raw_length
+ );
+ if (result == std::size_t(-1)) {
+ return result;
+ }
+ return b64_output(from_c(id), raw_length);
+}
+
+
+int olm_session_has_received_message(
+ OlmSession const * session
+) {
+ return from_c(session)->received_message;
+}
+
+void olm_session_describe(
+ OlmSession * session, char *buf, size_t buflen
+) {
+ from_c(session)->describe(buf, buflen);
+}
+
+size_t olm_matches_inbound_session(
+ OlmSession * session,
+ void * one_time_key_message, size_t message_length
+) {
+ std::size_t raw_length = b64_input(
+ from_c(one_time_key_message), message_length, from_c(session)->last_error
+ );
+ if (raw_length == std::size_t(-1)) {
+ return std::size_t(-1);
+ }
+ bool matches = from_c(session)->matches_inbound_session(
+ nullptr, from_c(one_time_key_message), raw_length
+ );
+ return matches ? 1 : 0;
+}
+
+
+size_t olm_matches_inbound_session_from(
+ OlmSession * session,
+ void const * their_identity_key, size_t their_identity_key_length,
+ void * one_time_key_message, size_t message_length
+) {
+ std::uint8_t const * id_key = from_c(their_identity_key);
+ std::size_t id_key_length = their_identity_key_length;
+
+ if (olm::decode_base64_length(id_key_length) != CURVE25519_KEY_LENGTH) {
+ from_c(session)->last_error = OlmErrorCode::OLM_INVALID_BASE64;
+ return std::size_t(-1);
+ }
+ _olm_curve25519_public_key identity_key;
+ olm::decode_base64(id_key, id_key_length, identity_key.public_key);
+
+ std::size_t raw_length = b64_input(
+ from_c(one_time_key_message), message_length, from_c(session)->last_error
+ );
+ if (raw_length == std::size_t(-1)) {
+ return std::size_t(-1);
+ }
+ bool matches = from_c(session)->matches_inbound_session(
+ &identity_key, from_c(one_time_key_message), raw_length
+ );
+ return matches ? 1 : 0;
+}
+
+
+size_t olm_remove_one_time_keys(
+ OlmAccount * account,
+ OlmSession * session
+) {
+ size_t result = from_c(account)->remove_key(
+ from_c(session)->bob_one_time_key
+ );
+ if (result == std::size_t(-1)) {
+ from_c(account)->last_error = OlmErrorCode::OLM_BAD_MESSAGE_KEY_ID;
+ }
+ return result;
+}
+
+
+size_t olm_encrypt_message_type(
+ OlmSession const * session
+) {
+ return size_t(from_c(session)->encrypt_message_type());
+}
+
+
+size_t olm_encrypt_random_length(
+ OlmSession const * session
+) {
+ return from_c(session)->encrypt_random_length();
+}
+
+
+size_t olm_encrypt_message_length(
+ OlmSession const * session,
+ size_t plaintext_length
+) {
+ return b64_output_length(
+ from_c(session)->encrypt_message_length(plaintext_length)
+ );
+}
+
+
+size_t olm_encrypt(
+ OlmSession * session,
+ void const * plaintext, size_t plaintext_length,
+ void * random, size_t random_length,
+ void * message, size_t message_length
+) {
+ std::size_t raw_length = from_c(session)->encrypt_message_length(
+ plaintext_length
+ );
+ if (message_length < b64_output_length(raw_length)) {
+ from_c(session)->last_error =
+ OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+ std::size_t result = from_c(session)->encrypt(
+ from_c(plaintext), plaintext_length,
+ from_c(random), random_length,
+ b64_output_pos(from_c(message), raw_length), raw_length
+ );
+ olm::unset(random, random_length);
+ if (result == std::size_t(-1)) {
+ return result;
+ }
+ return b64_output(from_c(message), raw_length);
+}
+
+
+size_t olm_decrypt_max_plaintext_length(
+ OlmSession * session,
+ size_t message_type,
+ void * message, size_t message_length
+) {
+ std::size_t raw_length = b64_input(
+ from_c(message), message_length, from_c(session)->last_error
+ );
+ if (raw_length == std::size_t(-1)) {
+ return std::size_t(-1);
+ }
+ return from_c(session)->decrypt_max_plaintext_length(
+ olm::MessageType(message_type), from_c(message), raw_length
+ );
+}
+
+
+size_t olm_decrypt(
+ OlmSession * session,
+ size_t message_type,
+ void * message, size_t message_length,
+ void * plaintext, size_t max_plaintext_length
+) {
+ std::size_t raw_length = b64_input(
+ from_c(message), message_length, from_c(session)->last_error
+ );
+ if (raw_length == std::size_t(-1)) {
+ return std::size_t(-1);
+ }
+ return from_c(session)->decrypt(
+ olm::MessageType(message_type), from_c(message), raw_length,
+ from_c(plaintext), max_plaintext_length
+ );
+}
+
+
+size_t olm_sha256_length(
+ OlmUtility const * utility
+) {
+ return b64_output_length(from_c(utility)->sha256_length());
+}
+
+
+size_t olm_sha256(
+ OlmUtility * utility,
+ void const * input, size_t input_length,
+ void * output, size_t output_length
+) {
+ std::size_t raw_length = from_c(utility)->sha256_length();
+ if (output_length < b64_output_length(raw_length)) {
+ from_c(utility)->last_error =
+ OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+ std::size_t result = from_c(utility)->sha256(
+ from_c(input), input_length,
+ b64_output_pos(from_c(output), raw_length), raw_length
+ );
+ if (result == std::size_t(-1)) {
+ return result;
+ }
+ return b64_output(from_c(output), raw_length);
+}
+
+
+size_t olm_ed25519_verify(
+ OlmUtility * utility,
+ void const * key, size_t key_length,
+ void const * message, size_t message_length,
+ void * signature, size_t signature_length
+) {
+ if (olm::decode_base64_length(key_length) != CURVE25519_KEY_LENGTH) {
+ from_c(utility)->last_error = OlmErrorCode::OLM_INVALID_BASE64;
+ return std::size_t(-1);
+ }
+ _olm_ed25519_public_key verify_key;
+ olm::decode_base64(from_c(key), key_length, verify_key.public_key);
+ std::size_t raw_signature_length = b64_input(
+ from_c(signature), signature_length, from_c(utility)->last_error
+ );
+ if (raw_signature_length == std::size_t(-1)) {
+ return std::size_t(-1);
+ }
+ return from_c(utility)->ed25519_verify(
+ verify_key,
+ from_c(message), message_length,
+ from_c(signature), raw_signature_length
+ );
+}
+
+}
diff --git a/ext/olm/src/outbound_group_session.c b/ext/olm/src/outbound_group_session.c
new file mode 100644
index 0000000..cbbba9c
--- /dev/null
+++ b/ext/olm/src/outbound_group_session.c
@@ -0,0 +1,390 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "olm/outbound_group_session.h"
+
+#include <string.h>
+
+#include "olm/base64.h"
+#include "olm/cipher.h"
+#include "olm/crypto.h"
+#include "olm/error.h"
+#include "olm/megolm.h"
+#include "olm/memory.h"
+#include "olm/message.h"
+#include "olm/pickle.h"
+#include "olm/pickle_encoding.h"
+
+#define OLM_PROTOCOL_VERSION 3
+#define GROUP_SESSION_ID_LENGTH ED25519_PUBLIC_KEY_LENGTH
+#define PICKLE_VERSION 1
+#define SESSION_KEY_VERSION 2
+
+struct OlmOutboundGroupSession {
+ /** the Megolm ratchet providing the encryption keys */
+ Megolm ratchet;
+
+ /** The ed25519 keypair used for signing the messages */
+ struct _olm_ed25519_key_pair signing_key;
+
+ enum OlmErrorCode last_error;
+};
+
+
+size_t olm_outbound_group_session_size(void) {
+ return sizeof(OlmOutboundGroupSession);
+}
+
+OlmOutboundGroupSession * olm_outbound_group_session(
+ void *memory
+) {
+ OlmOutboundGroupSession *session = memory;
+ olm_clear_outbound_group_session(session);
+ return session;
+}
+
+const char *olm_outbound_group_session_last_error(
+ const OlmOutboundGroupSession *session
+) {
+ return _olm_error_to_string(session->last_error);
+}
+
+enum OlmErrorCode olm_outbound_group_session_last_error_code(
+ const OlmOutboundGroupSession *session
+) {
+ return session->last_error;
+}
+
+size_t olm_clear_outbound_group_session(
+ OlmOutboundGroupSession *session
+) {
+ _olm_unset(session, sizeof(OlmOutboundGroupSession));
+ return sizeof(OlmOutboundGroupSession);
+}
+
+static size_t raw_pickle_length(
+ const OlmOutboundGroupSession *session
+) {
+ size_t length = 0;
+ length += _olm_pickle_uint32_length(PICKLE_VERSION);
+ length += megolm_pickle_length(&(session->ratchet));
+ length += _olm_pickle_ed25519_key_pair_length(&(session->signing_key));
+ return length;
+}
+
+size_t olm_pickle_outbound_group_session_length(
+ const OlmOutboundGroupSession *session
+) {
+ return _olm_enc_output_length(raw_pickle_length(session));
+}
+
+size_t olm_pickle_outbound_group_session(
+ OlmOutboundGroupSession *session,
+ void const * key, size_t key_length,
+ void * pickled, size_t pickled_length
+) {
+ size_t raw_length = raw_pickle_length(session);
+ uint8_t *pos;
+
+ if (pickled_length < _olm_enc_output_length(raw_length)) {
+ session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return (size_t)-1;
+ }
+
+#ifndef OLM_FUZZING
+ pos = _olm_enc_output_pos(pickled, raw_length);
+#else
+ pos = pickled;
+#endif
+
+ pos = _olm_pickle_uint32(pos, PICKLE_VERSION);
+ pos = megolm_pickle(&(session->ratchet), pos);
+ pos = _olm_pickle_ed25519_key_pair(pos, &(session->signing_key));
+
+#ifndef OLM_FUZZING
+ return _olm_enc_output(key, key_length, pickled, raw_length);
+#else
+ return raw_length;
+#endif
+}
+
+size_t olm_unpickle_outbound_group_session(
+ OlmOutboundGroupSession *session,
+ void const * key, size_t key_length,
+ void * pickled, size_t pickled_length
+) {
+ const uint8_t *pos;
+ const uint8_t *end;
+ uint32_t pickle_version;
+
+#ifndef OLM_FUZZING
+ size_t raw_length = _olm_enc_input(
+ key, key_length, pickled, pickled_length, &(session->last_error)
+ );
+#else
+ size_t raw_length = pickled_length;
+#endif
+
+ if (raw_length == (size_t)-1) {
+ return raw_length;
+ }
+
+ pos = pickled;
+ end = pos + raw_length;
+
+ pos = _olm_unpickle_uint32(pos, end, &pickle_version);
+ FAIL_ON_CORRUPTED_PICKLE(pos, session);
+
+ if (pickle_version != PICKLE_VERSION) {
+ session->last_error = OLM_UNKNOWN_PICKLE_VERSION;
+ return (size_t)-1;
+ }
+
+ pos = megolm_unpickle(&(session->ratchet), pos, end);
+ FAIL_ON_CORRUPTED_PICKLE(pos, session);
+
+ pos = _olm_unpickle_ed25519_key_pair(pos, end, &(session->signing_key));
+ FAIL_ON_CORRUPTED_PICKLE(pos, session);
+
+ if (pos != end) {
+ /* Input was longer than expected. */
+ session->last_error = OLM_PICKLE_EXTRA_DATA;
+ return (size_t)-1;
+ }
+
+ return pickled_length;
+}
+
+
+size_t olm_init_outbound_group_session_random_length(
+ const OlmOutboundGroupSession *session
+) {
+ /* we need data to initialize the megolm ratchet, plus some more for the
+ * session id.
+ */
+ return MEGOLM_RATCHET_LENGTH +
+ ED25519_RANDOM_LENGTH;
+}
+
+size_t olm_init_outbound_group_session(
+ OlmOutboundGroupSession *session,
+ uint8_t *random, size_t random_length
+) {
+ const uint8_t *random_ptr = random;
+
+ if (random_length < olm_init_outbound_group_session_random_length(session)) {
+ /* Insufficient random data for new session */
+ session->last_error = OLM_NOT_ENOUGH_RANDOM;
+ return (size_t)-1;
+ }
+
+ megolm_init(&(session->ratchet), random_ptr, 0);
+ random_ptr += MEGOLM_RATCHET_LENGTH;
+
+ _olm_crypto_ed25519_generate_key(random_ptr, &(session->signing_key));
+ random_ptr += ED25519_RANDOM_LENGTH;
+
+ _olm_unset(random, random_length);
+ return 0;
+}
+
+static size_t raw_message_length(
+ OlmOutboundGroupSession *session,
+ size_t plaintext_length)
+{
+ size_t ciphertext_length, mac_length;
+
+ ciphertext_length = megolm_cipher->ops->encrypt_ciphertext_length(
+ megolm_cipher, plaintext_length
+ );
+
+ mac_length = megolm_cipher->ops->mac_length(megolm_cipher);
+
+ return _olm_encode_group_message_length(
+ session->ratchet.counter,
+ ciphertext_length, mac_length, ED25519_SIGNATURE_LENGTH
+ );
+}
+
+size_t olm_group_encrypt_message_length(
+ OlmOutboundGroupSession *session,
+ size_t plaintext_length
+) {
+ size_t message_length = raw_message_length(session, plaintext_length);
+ return _olm_encode_base64_length(message_length);
+}
+
+/** write an un-base64-ed message to the buffer */
+static size_t _encrypt(
+ OlmOutboundGroupSession *session, uint8_t const * plaintext, size_t plaintext_length,
+ uint8_t * buffer
+) {
+ size_t ciphertext_length, mac_length, message_length;
+ size_t result;
+ uint8_t *ciphertext_ptr;
+
+ ciphertext_length = megolm_cipher->ops->encrypt_ciphertext_length(
+ megolm_cipher,
+ plaintext_length
+ );
+
+ mac_length = megolm_cipher->ops->mac_length(megolm_cipher);
+
+ /* first we build the message structure, then we encrypt
+ * the plaintext into it.
+ */
+ message_length = _olm_encode_group_message(
+ OLM_PROTOCOL_VERSION,
+ session->ratchet.counter,
+ ciphertext_length,
+ buffer,
+ &ciphertext_ptr);
+
+ message_length += mac_length;
+
+ result = megolm_cipher->ops->encrypt(
+ megolm_cipher,
+ megolm_get_data(&(session->ratchet)), MEGOLM_RATCHET_LENGTH,
+ plaintext, plaintext_length,
+ ciphertext_ptr, ciphertext_length,
+ buffer, message_length
+ );
+
+ if (result == (size_t)-1) {
+ return result;
+ }
+
+ megolm_advance(&(session->ratchet));
+
+ /* sign the whole thing with the ed25519 key. */
+ _olm_crypto_ed25519_sign(
+ &(session->signing_key),
+ buffer, message_length,
+ buffer + message_length
+ );
+
+ return result;
+}
+
+size_t olm_group_encrypt(
+ OlmOutboundGroupSession *session,
+ uint8_t const * plaintext, size_t plaintext_length,
+ uint8_t * message, size_t max_message_length
+) {
+ size_t rawmsglen;
+ size_t result;
+ uint8_t *message_pos;
+
+ rawmsglen = raw_message_length(session, plaintext_length);
+
+ if (max_message_length < _olm_encode_base64_length(rawmsglen)) {
+ session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return (size_t)-1;
+ }
+
+ /* we construct the message at the end of the buffer, so that
+ * we have room to base64-encode it once we're done.
+ */
+ message_pos = message + _olm_encode_base64_length(rawmsglen) - rawmsglen;
+
+ /* write the message, and encrypt it, at message_pos */
+ result = _encrypt(session, plaintext, plaintext_length, message_pos);
+ if (result == (size_t)-1) {
+ return result;
+ }
+
+ /* bas64-encode it */
+ return _olm_encode_base64(
+ message_pos, rawmsglen, message
+ );
+}
+
+
+size_t olm_outbound_group_session_id_length(
+ const OlmOutboundGroupSession *session
+) {
+ return _olm_encode_base64_length(GROUP_SESSION_ID_LENGTH);
+}
+
+size_t olm_outbound_group_session_id(
+ OlmOutboundGroupSession *session,
+ uint8_t * id, size_t id_length
+) {
+ if (id_length < olm_outbound_group_session_id_length(session)) {
+ session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return (size_t)-1;
+ }
+
+ return _olm_encode_base64(
+ session->signing_key.public_key.public_key, GROUP_SESSION_ID_LENGTH, id
+ );
+}
+
+uint32_t olm_outbound_group_session_message_index(
+ OlmOutboundGroupSession *session
+) {
+ return session->ratchet.counter;
+}
+
+#define SESSION_KEY_RAW_LENGTH \
+ (1 + 4 + MEGOLM_RATCHET_LENGTH + ED25519_PUBLIC_KEY_LENGTH\
+ + ED25519_SIGNATURE_LENGTH)
+
+size_t olm_outbound_group_session_key_length(
+ const OlmOutboundGroupSession *session
+) {
+ return _olm_encode_base64_length(SESSION_KEY_RAW_LENGTH);
+}
+
+size_t olm_outbound_group_session_key(
+ OlmOutboundGroupSession *session,
+ uint8_t * key, size_t key_length
+) {
+ uint8_t *raw;
+ uint8_t *ptr;
+ size_t encoded_length = olm_outbound_group_session_key_length(session);
+
+ if (key_length < encoded_length) {
+ session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return (size_t)-1;
+ }
+
+ /* put the raw data at the end of the output buffer. */
+ raw = ptr = key + encoded_length - SESSION_KEY_RAW_LENGTH;
+ *ptr++ = SESSION_KEY_VERSION;
+
+ uint32_t counter = session->ratchet.counter;
+ // Encode counter as a big endian 32-bit number.
+ for (unsigned i = 0; i < 4; i++) {
+ *ptr++ = 0xFF & (counter >> 24); counter <<= 8;
+ }
+
+ memcpy(ptr, megolm_get_data(&session->ratchet), MEGOLM_RATCHET_LENGTH);
+ ptr += MEGOLM_RATCHET_LENGTH;
+
+ memcpy(
+ ptr, session->signing_key.public_key.public_key,
+ ED25519_PUBLIC_KEY_LENGTH
+ );
+ ptr += ED25519_PUBLIC_KEY_LENGTH;
+
+ /* sign the whole thing with the ed25519 key. */
+ _olm_crypto_ed25519_sign(
+ &(session->signing_key),
+ raw, ptr - raw, ptr
+ );
+
+ return _olm_encode_base64(raw, SESSION_KEY_RAW_LENGTH, key);
+}
diff --git a/ext/olm/src/pickle.cpp b/ext/olm/src/pickle.cpp
new file mode 100644
index 0000000..3bffb36
--- /dev/null
+++ b/ext/olm/src/pickle.cpp
@@ -0,0 +1,274 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "olm/pickle.hh"
+#include "olm/pickle.h"
+
+std::uint8_t * olm::pickle(
+ std::uint8_t * pos,
+ std::uint32_t value
+) {
+ pos += 4;
+ for (unsigned i = 4; i--;) { *(--pos) = value; value >>= 8; }
+ return pos + 4;
+}
+
+
+std::uint8_t const * olm::unpickle(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ std::uint32_t & value
+) {
+ value = 0;
+ if (!pos || end < pos + 4) return nullptr;
+ for (unsigned i = 4; i--;) { value <<= 8; value |= *(pos++); }
+ return pos;
+}
+
+std::uint8_t * olm::pickle(
+ std::uint8_t * pos,
+ std::uint8_t value
+) {
+ *(pos++) = value;
+ return pos;
+}
+
+std::uint8_t const * olm::unpickle(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ std::uint8_t & value
+) {
+ if (!pos || pos == end) return nullptr;
+ value = *(pos++);
+ return pos;
+}
+
+std::uint8_t * olm::pickle(
+ std::uint8_t * pos,
+ bool value
+) {
+ *(pos++) = value ? 1 : 0;
+ return pos;
+}
+
+std::uint8_t const * olm::unpickle(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ bool & value
+) {
+ if (!pos || pos == end) return nullptr;
+ value = *(pos++);
+ return pos;
+}
+
+std::uint8_t * olm::pickle_bytes(
+ std::uint8_t * pos,
+ std::uint8_t const * bytes, std::size_t bytes_length
+) {
+ std::memcpy(pos, bytes, bytes_length);
+ return pos + bytes_length;
+}
+
+std::uint8_t const * olm::unpickle_bytes(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ std::uint8_t * bytes, std::size_t bytes_length
+) {
+ if (!pos || end < pos + bytes_length) return nullptr;
+ std::memcpy(bytes, pos, bytes_length);
+ return pos + bytes_length;
+}
+
+
+std::size_t olm::pickle_length(
+ const _olm_curve25519_public_key & value
+) {
+ return sizeof(value.public_key);
+}
+
+
+std::uint8_t * olm::pickle(
+ std::uint8_t * pos,
+ const _olm_curve25519_public_key & value
+) {
+ pos = olm::pickle_bytes(
+ pos, value.public_key, sizeof(value.public_key)
+ );
+ return pos;
+}
+
+
+std::uint8_t const * olm::unpickle(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ _olm_curve25519_public_key & value
+) {
+ return olm::unpickle_bytes(
+ pos, end, value.public_key, sizeof(value.public_key)
+ );
+}
+
+
+std::size_t olm::pickle_length(
+ const _olm_curve25519_key_pair & value
+) {
+ return sizeof(value.public_key.public_key)
+ + sizeof(value.private_key.private_key);
+}
+
+
+std::uint8_t * olm::pickle(
+ std::uint8_t * pos,
+ const _olm_curve25519_key_pair & value
+) {
+ pos = olm::pickle_bytes(
+ pos, value.public_key.public_key,
+ sizeof(value.public_key.public_key)
+ );
+ pos = olm::pickle_bytes(
+ pos, value.private_key.private_key,
+ sizeof(value.private_key.private_key)
+ );
+ return pos;
+}
+
+
+std::uint8_t const * olm::unpickle(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ _olm_curve25519_key_pair & value
+) {
+ pos = olm::unpickle_bytes(
+ pos, end, value.public_key.public_key,
+ sizeof(value.public_key.public_key)
+ );
+ if (!pos) return nullptr;
+
+ pos = olm::unpickle_bytes(
+ pos, end, value.private_key.private_key,
+ sizeof(value.private_key.private_key)
+ );
+ if (!pos) return nullptr;
+
+ return pos;
+}
+
+////// pickle.h implementations
+
+std::size_t _olm_pickle_ed25519_public_key_length(
+ const _olm_ed25519_public_key * value
+) {
+ return sizeof(value->public_key);
+}
+
+
+std::uint8_t * _olm_pickle_ed25519_public_key(
+ std::uint8_t * pos,
+ const _olm_ed25519_public_key *value
+) {
+ return olm::pickle_bytes(
+ pos, value->public_key, sizeof(value->public_key)
+ );
+}
+
+
+std::uint8_t const * _olm_unpickle_ed25519_public_key(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ _olm_ed25519_public_key * value
+) {
+ return olm::unpickle_bytes(
+ pos, end, value->public_key, sizeof(value->public_key)
+ );
+}
+
+
+std::size_t _olm_pickle_ed25519_key_pair_length(
+ const _olm_ed25519_key_pair *value
+) {
+ return sizeof(value->public_key.public_key)
+ + sizeof(value->private_key.private_key);
+}
+
+
+std::uint8_t * _olm_pickle_ed25519_key_pair(
+ std::uint8_t * pos,
+ const _olm_ed25519_key_pair *value
+) {
+ pos = olm::pickle_bytes(
+ pos, value->public_key.public_key,
+ sizeof(value->public_key.public_key)
+ );
+ pos = olm::pickle_bytes(
+ pos, value->private_key.private_key,
+ sizeof(value->private_key.private_key)
+ );
+ return pos;
+}
+
+
+std::uint8_t const * _olm_unpickle_ed25519_key_pair(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ _olm_ed25519_key_pair *value
+) {
+ pos = olm::unpickle_bytes(
+ pos, end, value->public_key.public_key,
+ sizeof(value->public_key.public_key)
+ );
+ if (!pos) return nullptr;
+
+ pos = olm::unpickle_bytes(
+ pos, end, value->private_key.private_key,
+ sizeof(value->private_key.private_key)
+ );
+ if (!pos) return nullptr;
+
+ return pos;
+}
+
+uint8_t * _olm_pickle_uint32(uint8_t * pos, uint32_t value) {
+ return olm::pickle(pos, value);
+}
+
+uint8_t const * _olm_unpickle_uint32(
+ uint8_t const * pos, uint8_t const * end,
+ uint32_t *value
+) {
+ return olm::unpickle(pos, end, *value);
+}
+
+uint8_t * _olm_pickle_uint8(uint8_t * pos, uint8_t value) {
+ return olm::pickle(pos, value);
+}
+
+uint8_t const * _olm_unpickle_uint8(
+ uint8_t const * pos, uint8_t const * end,
+ uint8_t *value
+) {
+ return olm::unpickle(pos, end, *value);
+}
+
+uint8_t * _olm_pickle_bool(uint8_t * pos, int value) {
+ return olm::pickle(pos, (bool)value);
+}
+
+uint8_t const * _olm_unpickle_bool(
+ uint8_t const * pos, uint8_t const * end,
+ int *value
+) {
+ return olm::unpickle(pos, end, *reinterpret_cast<bool *>(value));
+}
+
+uint8_t * _olm_pickle_bytes(uint8_t * pos, uint8_t const * bytes,
+ size_t bytes_length) {
+ return olm::pickle_bytes(pos, bytes, bytes_length);
+}
+
+uint8_t const * _olm_unpickle_bytes(uint8_t const * pos, uint8_t const * end,
+ uint8_t * bytes, size_t bytes_length) {
+ return olm::unpickle_bytes(pos, end, bytes, bytes_length);
+}
diff --git a/ext/olm/src/pickle_encoding.c b/ext/olm/src/pickle_encoding.c
new file mode 100644
index 0000000..a56e9e3
--- /dev/null
+++ b/ext/olm/src/pickle_encoding.c
@@ -0,0 +1,92 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "olm/pickle_encoding.h"
+
+#include "olm/base64.h"
+#include "olm/cipher.h"
+#include "olm/olm.h"
+
+static const struct _olm_cipher_aes_sha_256 PICKLE_CIPHER =
+ OLM_CIPHER_INIT_AES_SHA_256("Pickle");
+
+size_t _olm_enc_output_length(
+ size_t raw_length
+) {
+ const struct _olm_cipher *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER);
+ size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length);
+ length += cipher->ops->mac_length(cipher);
+ return _olm_encode_base64_length(length);
+}
+
+uint8_t * _olm_enc_output_pos(
+ uint8_t * output,
+ size_t raw_length
+) {
+ const struct _olm_cipher *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER);
+ size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length);
+ length += cipher->ops->mac_length(cipher);
+ return output + _olm_encode_base64_length(length) - length;
+}
+
+size_t _olm_enc_output(
+ uint8_t const * key, size_t key_length,
+ uint8_t * output, size_t raw_length
+) {
+ const struct _olm_cipher *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER);
+ size_t ciphertext_length = cipher->ops->encrypt_ciphertext_length(
+ cipher, raw_length
+ );
+ size_t length = ciphertext_length + cipher->ops->mac_length(cipher);
+ size_t base64_length = _olm_encode_base64_length(length);
+ uint8_t * raw_output = output + base64_length - length;
+ cipher->ops->encrypt(
+ cipher,
+ key, key_length,
+ raw_output, raw_length,
+ raw_output, ciphertext_length,
+ raw_output, length
+ );
+ _olm_encode_base64(raw_output, length, output);
+ return base64_length;
+}
+
+
+size_t _olm_enc_input(uint8_t const * key, size_t key_length,
+ uint8_t * input, size_t b64_length,
+ enum OlmErrorCode * last_error
+) {
+ size_t enc_length = _olm_decode_base64_length(b64_length);
+ if (enc_length == (size_t)-1) {
+ if (last_error) {
+ *last_error = OLM_INVALID_BASE64;
+ }
+ return (size_t)-1;
+ }
+ _olm_decode_base64(input, b64_length, input);
+ const struct _olm_cipher *cipher = OLM_CIPHER_BASE(&PICKLE_CIPHER);
+ size_t raw_length = enc_length - cipher->ops->mac_length(cipher);
+ size_t result = cipher->ops->decrypt(
+ cipher,
+ key, key_length,
+ input, enc_length,
+ input, raw_length,
+ input, raw_length
+ );
+ if (result == (size_t)-1 && last_error) {
+ *last_error = OLM_BAD_ACCOUNT_KEY;
+ }
+ return result;
+}
diff --git a/ext/olm/src/pk.cpp b/ext/olm/src/pk.cpp
new file mode 100644
index 0000000..9217c48
--- /dev/null
+++ b/ext/olm/src/pk.cpp
@@ -0,0 +1,542 @@
+/* Copyright 2018, 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "olm/pk.h"
+#include "olm/cipher.h"
+#include "olm/crypto.h"
+#include "olm/ratchet.hh"
+#include "olm/error.h"
+#include "olm/memory.hh"
+#include "olm/base64.hh"
+#include "olm/pickle_encoding.h"
+#include "olm/pickle.hh"
+
+static const std::size_t MAC_LENGTH = 8;
+
+const struct _olm_cipher_aes_sha_256 olm_pk_cipher_aes_sha256 =
+ OLM_CIPHER_INIT_AES_SHA_256("");
+const struct _olm_cipher *olm_pk_cipher =
+ OLM_CIPHER_BASE(&olm_pk_cipher_aes_sha256);
+
+extern "C" {
+
+struct OlmPkEncryption {
+ OlmErrorCode last_error;
+ _olm_curve25519_public_key recipient_key;
+};
+
+const char * olm_pk_encryption_last_error(
+ const OlmPkEncryption * encryption
+) {
+ auto error = encryption->last_error;
+ return _olm_error_to_string(error);
+}
+
+OlmErrorCode olm_pk_encryption_last_error_code(
+ const OlmPkEncryption * encryption
+) {
+ return encryption->last_error;
+}
+
+size_t olm_pk_encryption_size(void) {
+ return sizeof(OlmPkEncryption);
+}
+
+OlmPkEncryption *olm_pk_encryption(
+ void * memory
+) {
+ olm::unset(memory, sizeof(OlmPkEncryption));
+ return new(memory) OlmPkEncryption;
+}
+
+size_t olm_clear_pk_encryption(
+ OlmPkEncryption *encryption
+) {
+ /* Clear the memory backing the encryption */
+ olm::unset(encryption, sizeof(OlmPkEncryption));
+ /* Initialise a fresh encryption object in case someone tries to use it */
+ new(encryption) OlmPkEncryption();
+ return sizeof(OlmPkEncryption);
+}
+
+size_t olm_pk_encryption_set_recipient_key (
+ OlmPkEncryption *encryption,
+ void const * key, size_t key_length
+) {
+ if (key_length < olm_pk_key_length()) {
+ encryption->last_error =
+ OlmErrorCode::OLM_INPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+
+ olm::decode_base64(
+ (const uint8_t*)key,
+ olm_pk_key_length(),
+ (uint8_t *)encryption->recipient_key.public_key
+ );
+
+ return 0;
+}
+
+size_t olm_pk_ciphertext_length(
+ const OlmPkEncryption *encryption,
+ size_t plaintext_length
+) {
+ return olm::encode_base64_length(
+ _olm_cipher_aes_sha_256_ops.encrypt_ciphertext_length(olm_pk_cipher, plaintext_length)
+ );
+}
+
+size_t olm_pk_mac_length(
+ const OlmPkEncryption *encryption
+) {
+ return olm::encode_base64_length(_olm_cipher_aes_sha_256_ops.mac_length(olm_pk_cipher));
+}
+
+size_t olm_pk_encrypt_random_length(
+ const OlmPkEncryption *encryption
+) {
+ return CURVE25519_KEY_LENGTH;
+}
+
+size_t olm_pk_encrypt(
+ OlmPkEncryption *encryption,
+ void const * plaintext, size_t plaintext_length,
+ void * ciphertext, size_t ciphertext_length,
+ void * mac, size_t mac_length,
+ void * ephemeral_key, size_t ephemeral_key_size,
+ const void * random, size_t random_length
+) {
+ if (ciphertext_length
+ < olm_pk_ciphertext_length(encryption, plaintext_length)
+ || mac_length
+ < _olm_cipher_aes_sha_256_ops.mac_length(olm_pk_cipher)
+ || ephemeral_key_size
+ < olm_pk_key_length()) {
+ encryption->last_error =
+ OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+ if (random_length < olm_pk_encrypt_random_length(encryption)) {
+ encryption->last_error =
+ OlmErrorCode::OLM_NOT_ENOUGH_RANDOM;
+ return std::size_t(-1);
+ }
+
+ _olm_curve25519_key_pair ephemeral_keypair;
+ _olm_crypto_curve25519_generate_key((const uint8_t *) random, &ephemeral_keypair);
+ olm::encode_base64(
+ (const uint8_t *)ephemeral_keypair.public_key.public_key,
+ CURVE25519_KEY_LENGTH,
+ (uint8_t *)ephemeral_key
+ );
+
+ olm::SharedKey secret;
+ _olm_crypto_curve25519_shared_secret(&ephemeral_keypair, &encryption->recipient_key, secret);
+ size_t raw_ciphertext_length =
+ _olm_cipher_aes_sha_256_ops.encrypt_ciphertext_length(olm_pk_cipher, plaintext_length);
+ uint8_t *ciphertext_pos = (uint8_t *) ciphertext + ciphertext_length - raw_ciphertext_length;
+ uint8_t raw_mac[MAC_LENGTH];
+ size_t result = _olm_cipher_aes_sha_256_ops.encrypt(
+ olm_pk_cipher,
+ secret, sizeof(secret),
+ (const uint8_t *) plaintext, plaintext_length,
+ (uint8_t *) ciphertext_pos, raw_ciphertext_length,
+ (uint8_t *) raw_mac, MAC_LENGTH
+ );
+ if (result != std::size_t(-1)) {
+ olm::encode_base64(raw_mac, MAC_LENGTH, (uint8_t *)mac);
+ olm::encode_base64(ciphertext_pos, raw_ciphertext_length, (uint8_t *)ciphertext);
+ }
+ return result;
+}
+
+struct OlmPkDecryption {
+ OlmErrorCode last_error;
+ _olm_curve25519_key_pair key_pair;
+};
+
+const char * olm_pk_decryption_last_error(
+ const OlmPkDecryption * decryption
+) {
+ auto error = decryption->last_error;
+ return _olm_error_to_string(error);
+}
+
+OlmErrorCode olm_pk_decryption_last_error_code(
+ const OlmPkDecryption * decryption
+) {
+ return decryption->last_error;
+}
+
+size_t olm_pk_decryption_size(void) {
+ return sizeof(OlmPkDecryption);
+}
+
+OlmPkDecryption *olm_pk_decryption(
+ void * memory
+) {
+ olm::unset(memory, sizeof(OlmPkDecryption));
+ return new(memory) OlmPkDecryption;
+}
+
+size_t olm_clear_pk_decryption(
+ OlmPkDecryption *decryption
+) {
+ /* Clear the memory backing the decryption */
+ olm::unset(decryption, sizeof(OlmPkDecryption));
+ /* Initialise a fresh decryption object in case someone tries to use it */
+ new(decryption) OlmPkDecryption();
+ return sizeof(OlmPkDecryption);
+}
+
+size_t olm_pk_private_key_length(void) {
+ return CURVE25519_KEY_LENGTH;
+}
+
+size_t olm_pk_generate_key_random_length(void) {
+ return olm_pk_private_key_length();
+}
+
+size_t olm_pk_key_length(void) {
+ return olm::encode_base64_length(CURVE25519_KEY_LENGTH);
+}
+
+size_t olm_pk_key_from_private(
+ OlmPkDecryption * decryption,
+ void * pubkey, size_t pubkey_length,
+ const void * privkey, size_t privkey_length
+) {
+ if (pubkey_length < olm_pk_key_length()) {
+ decryption->last_error =
+ OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+ if (privkey_length < olm_pk_private_key_length()) {
+ decryption->last_error =
+ OlmErrorCode::OLM_INPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+
+ _olm_crypto_curve25519_generate_key((const uint8_t *) privkey, &decryption->key_pair);
+ olm::encode_base64(
+ (const uint8_t *)decryption->key_pair.public_key.public_key,
+ CURVE25519_KEY_LENGTH,
+ (uint8_t *)pubkey
+ );
+ return 0;
+}
+
+size_t olm_pk_generate_key(
+ OlmPkDecryption * decryption,
+ void * pubkey, size_t pubkey_length,
+ const void * privkey, size_t privkey_length
+) {
+ return olm_pk_key_from_private(decryption, pubkey, pubkey_length, privkey, privkey_length);
+}
+
+namespace {
+ static const std::uint32_t PK_DECRYPTION_PICKLE_VERSION = 1;
+
+ static std::size_t pickle_length(
+ OlmPkDecryption const & value
+ ) {
+ std::size_t length = 0;
+ length += olm::pickle_length(PK_DECRYPTION_PICKLE_VERSION);
+ length += olm::pickle_length(value.key_pair);
+ return length;
+ }
+
+
+ static std::uint8_t * pickle(
+ std::uint8_t * pos,
+ OlmPkDecryption const & value
+ ) {
+ pos = olm::pickle(pos, PK_DECRYPTION_PICKLE_VERSION);
+ pos = olm::pickle(pos, value.key_pair);
+ return pos;
+ }
+
+
+ static std::uint8_t const * unpickle(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ OlmPkDecryption & value
+ ) {
+ uint32_t pickle_version;
+ pos = olm::unpickle(pos, end, pickle_version); UNPICKLE_OK(pos);
+
+ switch (pickle_version) {
+ case 1:
+ break;
+
+ default:
+ value.last_error = OlmErrorCode::OLM_UNKNOWN_PICKLE_VERSION;
+ return nullptr;
+ }
+
+ pos = olm::unpickle(pos, end, value.key_pair); UNPICKLE_OK(pos);
+
+ return pos;
+ }
+}
+
+size_t olm_pickle_pk_decryption_length(
+ const OlmPkDecryption * decryption
+) {
+ return _olm_enc_output_length(pickle_length(*decryption));
+}
+
+size_t olm_pickle_pk_decryption(
+ OlmPkDecryption * decryption,
+ void const * key, size_t key_length,
+ void *pickled, size_t pickled_length
+) {
+ OlmPkDecryption & object = *decryption;
+ std::size_t raw_length = pickle_length(object);
+ if (pickled_length < _olm_enc_output_length(raw_length)) {
+ object.last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+ pickle(_olm_enc_output_pos(reinterpret_cast<std::uint8_t *>(pickled), raw_length), object);
+ return _olm_enc_output(
+ reinterpret_cast<std::uint8_t const *>(key), key_length,
+ reinterpret_cast<std::uint8_t *>(pickled), raw_length
+ );
+}
+
+size_t olm_unpickle_pk_decryption(
+ OlmPkDecryption * decryption,
+ void const * key, size_t key_length,
+ void *pickled, size_t pickled_length,
+ void *pubkey, size_t pubkey_length
+) {
+ OlmPkDecryption & object = *decryption;
+ if (pubkey != NULL && pubkey_length < olm_pk_key_length()) {
+ object.last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+ std::uint8_t * const input = reinterpret_cast<std::uint8_t *>(pickled);
+ std::size_t raw_length = _olm_enc_input(
+ reinterpret_cast<std::uint8_t const *>(key), key_length,
+ input, pickled_length, &object.last_error
+ );
+ if (raw_length == std::size_t(-1)) {
+ return std::size_t(-1);
+ }
+
+ std::uint8_t const * pos = input;
+ std::uint8_t const * end = pos + raw_length;
+
+ pos = unpickle(pos, end, object);
+
+ if (!pos) {
+ /* Input was corrupted. */
+ if (object.last_error == OlmErrorCode::OLM_SUCCESS) {
+ object.last_error = OlmErrorCode::OLM_CORRUPTED_PICKLE;
+ }
+ return std::size_t(-1);
+ } else if (pos != end) {
+ /* Input was longer than expected. */
+ object.last_error = OlmErrorCode::OLM_PICKLE_EXTRA_DATA;
+ return std::size_t(-1);
+ }
+
+ if (pubkey != NULL) {
+ olm::encode_base64(
+ (const uint8_t *)object.key_pair.public_key.public_key,
+ CURVE25519_KEY_LENGTH,
+ (uint8_t *)pubkey
+ );
+ }
+
+ return pickled_length;
+}
+
+size_t olm_pk_max_plaintext_length(
+ const OlmPkDecryption * decryption,
+ size_t ciphertext_length
+) {
+ return _olm_cipher_aes_sha_256_ops.decrypt_max_plaintext_length(
+ olm_pk_cipher, olm::decode_base64_length(ciphertext_length)
+ );
+}
+
+size_t olm_pk_decrypt(
+ OlmPkDecryption * decryption,
+ void const * ephemeral_key, size_t ephemeral_key_length,
+ void const * mac, size_t mac_length,
+ void * ciphertext, size_t ciphertext_length,
+ void * plaintext, size_t max_plaintext_length
+) {
+ if (max_plaintext_length
+ < olm_pk_max_plaintext_length(decryption, ciphertext_length)) {
+ decryption->last_error =
+ OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+
+ size_t raw_ciphertext_length = olm::decode_base64_length(ciphertext_length);
+
+ if (ephemeral_key_length != olm::encode_base64_length(CURVE25519_KEY_LENGTH)
+ || mac_length != olm::encode_base64_length(MAC_LENGTH)
+ || raw_ciphertext_length == std::size_t(-1)) {
+ decryption->last_error = OlmErrorCode::OLM_INVALID_BASE64;
+ return std::size_t(-1);
+ }
+
+ struct _olm_curve25519_public_key ephemeral;
+ olm::decode_base64(
+ (const uint8_t*)ephemeral_key,
+ olm::encode_base64_length(CURVE25519_KEY_LENGTH),
+ (uint8_t *)ephemeral.public_key
+ );
+
+ olm::SharedKey secret;
+ _olm_crypto_curve25519_shared_secret(&decryption->key_pair, &ephemeral, secret);
+
+ uint8_t raw_mac[MAC_LENGTH];
+ olm::decode_base64(
+ (const uint8_t *)mac,
+ olm::encode_base64_length(MAC_LENGTH),
+ raw_mac
+ );
+
+ olm::decode_base64(
+ (const uint8_t *)ciphertext,
+ ciphertext_length,
+ (uint8_t *)ciphertext
+ );
+
+ size_t result = _olm_cipher_aes_sha_256_ops.decrypt(
+ olm_pk_cipher,
+ secret, sizeof(secret),
+ (uint8_t *) raw_mac, MAC_LENGTH,
+ (const uint8_t *) ciphertext, raw_ciphertext_length,
+ (uint8_t *) plaintext, max_plaintext_length
+ );
+ if (result == std::size_t(-1)) {
+ // we already checked the buffer sizes, so the only error that decrypt
+ // will return is if the MAC is incorrect
+ decryption->last_error =
+ OlmErrorCode::OLM_BAD_MESSAGE_MAC;
+ return std::size_t(-1);
+ } else {
+ return result;
+ }
+}
+
+size_t olm_pk_get_private_key(
+ OlmPkDecryption * decryption,
+ void *private_key, size_t private_key_length
+) {
+ if (private_key_length < olm_pk_private_key_length()) {
+ decryption->last_error =
+ OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+ std::memcpy(
+ private_key,
+ decryption->key_pair.private_key.private_key,
+ olm_pk_private_key_length()
+ );
+ return olm_pk_private_key_length();
+}
+
+struct OlmPkSigning {
+ OlmErrorCode last_error;
+ _olm_ed25519_key_pair key_pair;
+};
+
+size_t olm_pk_signing_size(void) {
+ return sizeof(OlmPkSigning);
+}
+
+OlmPkSigning *olm_pk_signing(void * memory) {
+ olm::unset(memory, sizeof(OlmPkSigning));
+ return new(memory) OlmPkSigning;
+}
+
+const char * olm_pk_signing_last_error(const OlmPkSigning * sign) {
+ auto error = sign->last_error;
+ return _olm_error_to_string(error);
+}
+
+OlmErrorCode olm_pk_signing_last_error_code(const OlmPkSigning * sign) {
+ return sign->last_error;
+}
+
+size_t olm_clear_pk_signing(OlmPkSigning *sign) {
+ /* Clear the memory backing the signing */
+ olm::unset(sign, sizeof(OlmPkSigning));
+ /* Initialise a fresh signing object in case someone tries to use it */
+ new(sign) OlmPkSigning();
+ return sizeof(OlmPkSigning);
+}
+
+size_t olm_pk_signing_seed_length(void) {
+ return ED25519_RANDOM_LENGTH;
+}
+
+size_t olm_pk_signing_public_key_length(void) {
+ return olm::encode_base64_length(ED25519_PUBLIC_KEY_LENGTH);
+}
+
+size_t olm_pk_signing_key_from_seed(
+ OlmPkSigning * signing,
+ void * pubkey, size_t pubkey_length,
+ const void * seed, size_t seed_length
+) {
+ if (pubkey_length < olm_pk_signing_public_key_length()) {
+ signing->last_error =
+ OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+ if (seed_length < olm_pk_signing_seed_length()) {
+ signing->last_error =
+ OlmErrorCode::OLM_INPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+
+ _olm_crypto_ed25519_generate_key((const uint8_t *) seed, &signing->key_pair);
+ olm::encode_base64(
+ (const uint8_t *)signing->key_pair.public_key.public_key,
+ ED25519_PUBLIC_KEY_LENGTH,
+ (uint8_t *)pubkey
+ );
+ return 0;
+}
+
+size_t olm_pk_signature_length(void) {
+ return olm::encode_base64_length(ED25519_SIGNATURE_LENGTH);
+}
+
+size_t olm_pk_sign(
+ OlmPkSigning *signing,
+ uint8_t const * message, size_t message_length,
+ uint8_t * signature, size_t signature_length
+) {
+ if (signature_length < olm_pk_signature_length()) {
+ signing->last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+ uint8_t *raw_sig = signature + olm_pk_signature_length() - ED25519_SIGNATURE_LENGTH;
+ _olm_crypto_ed25519_sign(
+ &signing->key_pair,
+ message, message_length, raw_sig
+ );
+ olm::encode_base64(raw_sig, ED25519_SIGNATURE_LENGTH, signature);
+ return olm_pk_signature_length();
+}
+
+}
diff --git a/ext/olm/src/ratchet.cpp b/ext/olm/src/ratchet.cpp
new file mode 100644
index 0000000..1d284a6
--- /dev/null
+++ b/ext/olm/src/ratchet.cpp
@@ -0,0 +1,625 @@
+/* Copyright 2015, 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "olm/ratchet.hh"
+#include "olm/message.hh"
+#include "olm/memory.hh"
+#include "olm/cipher.h"
+#include "olm/pickle.hh"
+
+#include <cstring>
+
+namespace {
+
+static const std::uint8_t PROTOCOL_VERSION = 3;
+static const std::uint8_t MESSAGE_KEY_SEED[1] = {0x01};
+static const std::uint8_t CHAIN_KEY_SEED[1] = {0x02};
+static const std::size_t MAX_MESSAGE_GAP = 2000;
+
+
+/**
+ * Advance the root key, creating a new message chain.
+ *
+ * @param root_key previous root key R(n-1)
+ * @param our_key our new ratchet key T(n)
+ * @param their_key their most recent ratchet key T(n-1)
+ * @param info table of constants for the ratchet function
+ * @param new_root_key[out] returns the new root key R(n)
+ * @param new_chain_key[out] returns the first chain key in the new chain
+ * C(n,0)
+ */
+static void create_chain_key(
+ olm::SharedKey const & root_key,
+ _olm_curve25519_key_pair const & our_key,
+ _olm_curve25519_public_key const & their_key,
+ olm::KdfInfo const & info,
+ olm::SharedKey & new_root_key,
+ olm::ChainKey & new_chain_key
+) {
+ olm::SharedKey secret;
+ _olm_crypto_curve25519_shared_secret(&our_key, &their_key, secret);
+ std::uint8_t derived_secrets[2 * olm::OLM_SHARED_KEY_LENGTH];
+ _olm_crypto_hkdf_sha256(
+ secret, sizeof(secret),
+ root_key, sizeof(root_key),
+ info.ratchet_info, info.ratchet_info_length,
+ derived_secrets, sizeof(derived_secrets)
+ );
+ std::uint8_t const * pos = derived_secrets;
+ pos = olm::load_array(new_root_key, pos);
+ pos = olm::load_array(new_chain_key.key, pos);
+ new_chain_key.index = 0;
+ olm::unset(derived_secrets);
+ olm::unset(secret);
+}
+
+
+static void advance_chain_key(
+ olm::ChainKey const & chain_key,
+ olm::ChainKey & new_chain_key
+) {
+ _olm_crypto_hmac_sha256(
+ chain_key.key, sizeof(chain_key.key),
+ CHAIN_KEY_SEED, sizeof(CHAIN_KEY_SEED),
+ new_chain_key.key
+ );
+ new_chain_key.index = chain_key.index + 1;
+}
+
+
+static void create_message_keys(
+ olm::ChainKey const & chain_key,
+ olm::KdfInfo const & info,
+ olm::MessageKey & message_key) {
+ _olm_crypto_hmac_sha256(
+ chain_key.key, sizeof(chain_key.key),
+ MESSAGE_KEY_SEED, sizeof(MESSAGE_KEY_SEED),
+ message_key.key
+ );
+ message_key.index = chain_key.index;
+}
+
+
+static std::size_t verify_mac_and_decrypt(
+ _olm_cipher const *cipher,
+ olm::MessageKey const & message_key,
+ olm::MessageReader const & reader,
+ std::uint8_t * plaintext, std::size_t max_plaintext_length
+) {
+ return cipher->ops->decrypt(
+ cipher,
+ message_key.key, sizeof(message_key.key),
+ reader.input, reader.input_length,
+ reader.ciphertext, reader.ciphertext_length,
+ plaintext, max_plaintext_length
+ );
+}
+
+
+static std::size_t verify_mac_and_decrypt_for_existing_chain(
+ olm::Ratchet const & session,
+ olm::ChainKey const & chain,
+ olm::MessageReader const & reader,
+ std::uint8_t * plaintext, std::size_t max_plaintext_length
+) {
+ if (reader.counter < chain.index) {
+ return std::size_t(-1);
+ }
+
+ /* Limit the number of hashes we're prepared to compute */
+ if (reader.counter - chain.index > MAX_MESSAGE_GAP) {
+ return std::size_t(-1);
+ }
+
+ olm::ChainKey new_chain = chain;
+
+ while (new_chain.index < reader.counter) {
+ advance_chain_key(new_chain, new_chain);
+ }
+
+ olm::MessageKey message_key;
+ create_message_keys(new_chain, session.kdf_info, message_key);
+
+ std::size_t result = verify_mac_and_decrypt(
+ session.ratchet_cipher, message_key, reader,
+ plaintext, max_plaintext_length
+ );
+
+ olm::unset(new_chain);
+ return result;
+}
+
+
+static std::size_t verify_mac_and_decrypt_for_new_chain(
+ olm::Ratchet const & session,
+ olm::MessageReader const & reader,
+ std::uint8_t * plaintext, std::size_t max_plaintext_length
+) {
+ olm::SharedKey new_root_key;
+ olm::ReceiverChain new_chain;
+
+ /* They shouldn't move to a new chain until we've sent them a message
+ * acknowledging the last one */
+ if (session.sender_chain.empty()) {
+ return std::size_t(-1);
+ }
+
+ /* Limit the number of hashes we're prepared to compute */
+ if (reader.counter > MAX_MESSAGE_GAP) {
+ return std::size_t(-1);
+ }
+ olm::load_array(new_chain.ratchet_key.public_key, reader.ratchet_key);
+
+ create_chain_key(
+ session.root_key, session.sender_chain[0].ratchet_key,
+ new_chain.ratchet_key, session.kdf_info,
+ new_root_key, new_chain.chain_key
+ );
+ std::size_t result = verify_mac_and_decrypt_for_existing_chain(
+ session, new_chain.chain_key, reader,
+ plaintext, max_plaintext_length
+ );
+ olm::unset(new_root_key);
+ olm::unset(new_chain);
+ return result;
+}
+
+} // namespace
+
+
+olm::Ratchet::Ratchet(
+ olm::KdfInfo const & kdf_info,
+ _olm_cipher const * ratchet_cipher
+) : kdf_info(kdf_info),
+ ratchet_cipher(ratchet_cipher),
+ last_error(OlmErrorCode::OLM_SUCCESS) {
+}
+
+
+void olm::Ratchet::initialise_as_bob(
+ std::uint8_t const * shared_secret, std::size_t shared_secret_length,
+ _olm_curve25519_public_key const & their_ratchet_key
+) {
+ std::uint8_t derived_secrets[2 * olm::OLM_SHARED_KEY_LENGTH];
+ _olm_crypto_hkdf_sha256(
+ shared_secret, shared_secret_length,
+ nullptr, 0,
+ kdf_info.root_info, kdf_info.root_info_length,
+ derived_secrets, sizeof(derived_secrets)
+ );
+ receiver_chains.insert();
+ receiver_chains[0].chain_key.index = 0;
+ std::uint8_t const * pos = derived_secrets;
+ pos = olm::load_array(root_key, pos);
+ pos = olm::load_array(receiver_chains[0].chain_key.key, pos);
+ receiver_chains[0].ratchet_key = their_ratchet_key;
+ olm::unset(derived_secrets);
+}
+
+
+void olm::Ratchet::initialise_as_alice(
+ std::uint8_t const * shared_secret, std::size_t shared_secret_length,
+ _olm_curve25519_key_pair const & our_ratchet_key
+) {
+ std::uint8_t derived_secrets[2 * olm::OLM_SHARED_KEY_LENGTH];
+ _olm_crypto_hkdf_sha256(
+ shared_secret, shared_secret_length,
+ nullptr, 0,
+ kdf_info.root_info, kdf_info.root_info_length,
+ derived_secrets, sizeof(derived_secrets)
+ );
+ sender_chain.insert();
+ sender_chain[0].chain_key.index = 0;
+ std::uint8_t const * pos = derived_secrets;
+ pos = olm::load_array(root_key, pos);
+ pos = olm::load_array(sender_chain[0].chain_key.key, pos);
+ sender_chain[0].ratchet_key = our_ratchet_key;
+ olm::unset(derived_secrets);
+}
+
+namespace olm {
+
+
+static std::size_t pickle_length(
+ const olm::SharedKey & value
+) {
+ return olm::OLM_SHARED_KEY_LENGTH;
+}
+
+
+static std::uint8_t * pickle(
+ std::uint8_t * pos,
+ const olm::SharedKey & value
+) {
+ return olm::pickle_bytes(pos, value, olm::OLM_SHARED_KEY_LENGTH);
+}
+
+
+static std::uint8_t const * unpickle(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ olm::SharedKey & value
+) {
+ return olm::unpickle_bytes(pos, end, value, olm::OLM_SHARED_KEY_LENGTH);
+}
+
+
+static std::size_t pickle_length(
+ const olm::SenderChain & value
+) {
+ std::size_t length = 0;
+ length += olm::pickle_length(value.ratchet_key);
+ length += olm::pickle_length(value.chain_key.key);
+ length += olm::pickle_length(value.chain_key.index);
+ return length;
+}
+
+
+static std::uint8_t * pickle(
+ std::uint8_t * pos,
+ const olm::SenderChain & value
+) {
+ pos = olm::pickle(pos, value.ratchet_key);
+ pos = olm::pickle(pos, value.chain_key.key);
+ pos = olm::pickle(pos, value.chain_key.index);
+ return pos;
+}
+
+
+static std::uint8_t const * unpickle(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ olm::SenderChain & value
+) {
+ pos = olm::unpickle(pos, end, value.ratchet_key); UNPICKLE_OK(pos);
+ pos = olm::unpickle(pos, end, value.chain_key.key); UNPICKLE_OK(pos);
+ pos = olm::unpickle(pos, end, value.chain_key.index); UNPICKLE_OK(pos);
+ return pos;
+}
+
+static std::size_t pickle_length(
+ const olm::ReceiverChain & value
+) {
+ std::size_t length = 0;
+ length += olm::pickle_length(value.ratchet_key);
+ length += olm::pickle_length(value.chain_key.key);
+ length += olm::pickle_length(value.chain_key.index);
+ return length;
+}
+
+
+static std::uint8_t * pickle(
+ std::uint8_t * pos,
+ const olm::ReceiverChain & value
+) {
+ pos = olm::pickle(pos, value.ratchet_key);
+ pos = olm::pickle(pos, value.chain_key.key);
+ pos = olm::pickle(pos, value.chain_key.index);
+ return pos;
+}
+
+
+static std::uint8_t const * unpickle(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ olm::ReceiverChain & value
+) {
+ pos = olm::unpickle(pos, end, value.ratchet_key); UNPICKLE_OK(pos);
+ pos = olm::unpickle(pos, end, value.chain_key.key); UNPICKLE_OK(pos);
+ pos = olm::unpickle(pos, end, value.chain_key.index); UNPICKLE_OK(pos);
+ return pos;
+}
+
+
+static std::size_t pickle_length(
+ const olm::SkippedMessageKey & value
+) {
+ std::size_t length = 0;
+ length += olm::pickle_length(value.ratchet_key);
+ length += olm::pickle_length(value.message_key.key);
+ length += olm::pickle_length(value.message_key.index);
+ return length;
+}
+
+
+static std::uint8_t * pickle(
+ std::uint8_t * pos,
+ const olm::SkippedMessageKey & value
+) {
+ pos = olm::pickle(pos, value.ratchet_key);
+ pos = olm::pickle(pos, value.message_key.key);
+ pos = olm::pickle(pos, value.message_key.index);
+ return pos;
+}
+
+
+static std::uint8_t const * unpickle(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ olm::SkippedMessageKey & value
+) {
+ pos = olm::unpickle(pos, end, value.ratchet_key); UNPICKLE_OK(pos);
+ pos = olm::unpickle(pos, end, value.message_key.key); UNPICKLE_OK(pos);
+ pos = olm::unpickle(pos, end, value.message_key.index); UNPICKLE_OK(pos);
+ return pos;
+}
+
+
+} // namespace olm
+
+
+std::size_t olm::pickle_length(
+ olm::Ratchet const & value
+) {
+ std::size_t length = 0;
+ length += olm::OLM_SHARED_KEY_LENGTH;
+ length += olm::pickle_length(value.sender_chain);
+ length += olm::pickle_length(value.receiver_chains);
+ length += olm::pickle_length(value.skipped_message_keys);
+ return length;
+}
+
+std::uint8_t * olm::pickle(
+ std::uint8_t * pos,
+ olm::Ratchet const & value
+) {
+ pos = pickle(pos, value.root_key);
+ pos = pickle(pos, value.sender_chain);
+ pos = pickle(pos, value.receiver_chains);
+ pos = pickle(pos, value.skipped_message_keys);
+ return pos;
+}
+
+
+std::uint8_t const * olm::unpickle(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ olm::Ratchet & value,
+ bool includes_chain_index
+) {
+ pos = unpickle(pos, end, value.root_key); UNPICKLE_OK(pos);
+ pos = unpickle(pos, end, value.sender_chain); UNPICKLE_OK(pos);
+ pos = unpickle(pos, end, value.receiver_chains); UNPICKLE_OK(pos);
+ pos = unpickle(pos, end, value.skipped_message_keys); UNPICKLE_OK(pos);
+
+ // pickle v 0x80000001 includes a chain index; pickle v1 does not.
+ if (includes_chain_index) {
+ std::uint32_t dummy;
+ pos = unpickle(pos, end, dummy); UNPICKLE_OK(pos);
+ }
+ return pos;
+}
+
+
+std::size_t olm::Ratchet::encrypt_output_length(
+ std::size_t plaintext_length
+) const {
+ std::size_t counter = 0;
+ if (!sender_chain.empty()) {
+ counter = sender_chain[0].chain_key.index;
+ }
+ std::size_t padded = ratchet_cipher->ops->encrypt_ciphertext_length(
+ ratchet_cipher,
+ plaintext_length
+ );
+ return olm::encode_message_length(
+ counter, CURVE25519_KEY_LENGTH, padded, ratchet_cipher->ops->mac_length(ratchet_cipher)
+ );
+}
+
+
+std::size_t olm::Ratchet::encrypt_random_length() const {
+ return sender_chain.empty() ? CURVE25519_RANDOM_LENGTH : 0;
+}
+
+
+std::size_t olm::Ratchet::encrypt(
+ std::uint8_t const * plaintext, std::size_t plaintext_length,
+ std::uint8_t const * random, std::size_t random_length,
+ std::uint8_t * output, std::size_t max_output_length
+) {
+ std::size_t output_length = encrypt_output_length(plaintext_length);
+
+ if (random_length < encrypt_random_length()) {
+ last_error = OlmErrorCode::OLM_NOT_ENOUGH_RANDOM;
+ return std::size_t(-1);
+ }
+ if (max_output_length < output_length) {
+ last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+
+ if (sender_chain.empty()) {
+ sender_chain.insert();
+ _olm_crypto_curve25519_generate_key(random, &sender_chain[0].ratchet_key);
+ create_chain_key(
+ root_key,
+ sender_chain[0].ratchet_key,
+ receiver_chains[0].ratchet_key,
+ kdf_info,
+ root_key, sender_chain[0].chain_key
+ );
+ }
+
+ MessageKey keys;
+ create_message_keys(sender_chain[0].chain_key, kdf_info, keys);
+ advance_chain_key(sender_chain[0].chain_key, sender_chain[0].chain_key);
+
+ std::size_t ciphertext_length = ratchet_cipher->ops->encrypt_ciphertext_length(
+ ratchet_cipher,
+ plaintext_length
+ );
+ std::uint32_t counter = keys.index;
+ _olm_curve25519_public_key const & ratchet_key =
+ sender_chain[0].ratchet_key.public_key;
+
+ olm::MessageWriter writer;
+
+ olm::encode_message(
+ writer, PROTOCOL_VERSION, counter, CURVE25519_KEY_LENGTH,
+ ciphertext_length,
+ output
+ );
+
+ olm::store_array(writer.ratchet_key, ratchet_key.public_key);
+
+ ratchet_cipher->ops->encrypt(
+ ratchet_cipher,
+ keys.key, sizeof(keys.key),
+ plaintext, plaintext_length,
+ writer.ciphertext, ciphertext_length,
+ output, output_length
+ );
+
+ olm::unset(keys);
+ return output_length;
+}
+
+
+std::size_t olm::Ratchet::decrypt_max_plaintext_length(
+ std::uint8_t const * input, std::size_t input_length
+) {
+ olm::MessageReader reader;
+ olm::decode_message(
+ reader, input, input_length,
+ ratchet_cipher->ops->mac_length(ratchet_cipher)
+ );
+
+ if (!reader.ciphertext) {
+ last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT;
+ return std::size_t(-1);
+ }
+
+ return ratchet_cipher->ops->decrypt_max_plaintext_length(
+ ratchet_cipher, reader.ciphertext_length);
+}
+
+
+std::size_t olm::Ratchet::decrypt(
+ std::uint8_t const * input, std::size_t input_length,
+ std::uint8_t * plaintext, std::size_t max_plaintext_length
+) {
+ olm::MessageReader reader;
+ olm::decode_message(
+ reader, input, input_length,
+ ratchet_cipher->ops->mac_length(ratchet_cipher)
+ );
+
+ if (reader.version != PROTOCOL_VERSION) {
+ last_error = OlmErrorCode::OLM_BAD_MESSAGE_VERSION;
+ return std::size_t(-1);
+ }
+
+ if (!reader.has_counter || !reader.ratchet_key || !reader.ciphertext) {
+ last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT;
+ return std::size_t(-1);
+ }
+
+ std::size_t max_length = ratchet_cipher->ops->decrypt_max_plaintext_length(
+ ratchet_cipher,
+ reader.ciphertext_length
+ );
+
+ if (max_plaintext_length < max_length) {
+ last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+
+ if (reader.ratchet_key_length != CURVE25519_KEY_LENGTH) {
+ last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT;
+ return std::size_t(-1);
+ }
+
+ ReceiverChain * chain = nullptr;
+
+ for (olm::ReceiverChain & receiver_chain : receiver_chains) {
+ if (0 == std::memcmp(
+ receiver_chain.ratchet_key.public_key, reader.ratchet_key,
+ CURVE25519_KEY_LENGTH
+ )) {
+ chain = &receiver_chain;
+ break;
+ }
+ }
+
+ std::size_t result = std::size_t(-1);
+
+ if (!chain) {
+ result = verify_mac_and_decrypt_for_new_chain(
+ *this, reader, plaintext, max_plaintext_length
+ );
+ } else if (chain->chain_key.index > reader.counter) {
+ /* Chain already advanced beyond the key for this message
+ * Check if the message keys are in the skipped key list. */
+ for (olm::SkippedMessageKey & skipped : skipped_message_keys) {
+ if (reader.counter == skipped.message_key.index
+ && 0 == std::memcmp(
+ skipped.ratchet_key.public_key, reader.ratchet_key,
+ CURVE25519_KEY_LENGTH
+ )
+ ) {
+ /* Found the key for this message. Check the MAC. */
+
+ result = verify_mac_and_decrypt(
+ ratchet_cipher, skipped.message_key, reader,
+ plaintext, max_plaintext_length
+ );
+
+ if (result != std::size_t(-1)) {
+ /* Remove the key from the skipped keys now that we've
+ * decoded the message it corresponds to. */
+ olm::unset(skipped);
+ skipped_message_keys.erase(&skipped);
+ return result;
+ }
+ }
+ }
+ } else {
+ result = verify_mac_and_decrypt_for_existing_chain(
+ *this, chain->chain_key,
+ reader, plaintext, max_plaintext_length
+ );
+ }
+
+ if (result == std::size_t(-1)) {
+ last_error = OlmErrorCode::OLM_BAD_MESSAGE_MAC;
+ return std::size_t(-1);
+ }
+
+ if (!chain) {
+ /* They have started using a new ephemeral ratchet key.
+ * We need to derive a new set of chain keys.
+ * We can discard our previous ephemeral ratchet key.
+ * We will generate a new key when we send the next message. */
+
+ chain = receiver_chains.insert();
+ olm::load_array(chain->ratchet_key.public_key, reader.ratchet_key);
+
+ // TODO: we've already done this once, in
+ // verify_mac_and_decrypt_for_new_chain(). we could reuse the result.
+ create_chain_key(
+ root_key, sender_chain[0].ratchet_key, chain->ratchet_key,
+ kdf_info, root_key, chain->chain_key
+ );
+
+ olm::unset(sender_chain[0]);
+ sender_chain.erase(sender_chain.begin());
+ }
+
+ while (chain->chain_key.index < reader.counter) {
+ olm::SkippedMessageKey & key = *skipped_message_keys.insert();
+ create_message_keys(chain->chain_key, kdf_info, key.message_key);
+ key.ratchet_key = chain->ratchet_key;
+ advance_chain_key(chain->chain_key, chain->chain_key);
+ }
+
+ advance_chain_key(chain->chain_key, chain->chain_key);
+
+ return result;
+}
diff --git a/ext/olm/src/sas.c b/ext/olm/src/sas.c
new file mode 100644
index 0000000..d9cec7e
--- /dev/null
+++ b/ext/olm/src/sas.c
@@ -0,0 +1,229 @@
+/* Copyright 2018-2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "olm/sas.h"
+#include "olm/base64.h"
+#include "olm/crypto.h"
+#include "olm/error.h"
+#include "olm/memory.h"
+
+struct OlmSAS {
+ enum OlmErrorCode last_error;
+ struct _olm_curve25519_key_pair curve25519_key;
+ uint8_t secret[CURVE25519_SHARED_SECRET_LENGTH];
+ int their_key_set;
+};
+
+const char * olm_sas_last_error(
+ const OlmSAS * sas
+) {
+ return _olm_error_to_string(sas->last_error);
+}
+
+enum OlmErrorCode olm_sas_last_error_code(
+ const OlmSAS * sas
+) {
+ return sas->last_error;
+}
+
+size_t olm_sas_size(void) {
+ return sizeof(OlmSAS);
+}
+
+OlmSAS * olm_sas(
+ void * memory
+) {
+ _olm_unset(memory, sizeof(OlmSAS));
+ return (OlmSAS *) memory;
+}
+
+size_t olm_clear_sas(
+ OlmSAS * sas
+) {
+ _olm_unset(sas, sizeof(OlmSAS));
+ return sizeof(OlmSAS);
+}
+
+size_t olm_create_sas_random_length(const OlmSAS * sas) {
+ return CURVE25519_KEY_LENGTH;
+}
+
+size_t olm_create_sas(
+ OlmSAS * sas,
+ void * random, size_t random_length
+) {
+ if (random_length < olm_create_sas_random_length(sas)) {
+ sas->last_error = OLM_NOT_ENOUGH_RANDOM;
+ return (size_t)-1;
+ }
+ _olm_crypto_curve25519_generate_key((uint8_t *) random, &sas->curve25519_key);
+ sas->their_key_set = 0;
+ return 0;
+}
+
+size_t olm_sas_pubkey_length(const OlmSAS * sas) {
+ return _olm_encode_base64_length(CURVE25519_KEY_LENGTH);
+}
+
+size_t olm_sas_get_pubkey(
+ OlmSAS * sas,
+ void * pubkey, size_t pubkey_length
+) {
+ if (pubkey_length < olm_sas_pubkey_length(sas)) {
+ sas->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return (size_t)-1;
+ }
+ _olm_encode_base64(
+ (const uint8_t *)sas->curve25519_key.public_key.public_key,
+ CURVE25519_KEY_LENGTH,
+ (uint8_t *)pubkey
+ );
+ return 0;
+}
+
+size_t olm_sas_set_their_key(
+ OlmSAS *sas,
+ void * their_key, size_t their_key_length
+) {
+ if (their_key_length < olm_sas_pubkey_length(sas)) {
+ sas->last_error = OLM_INPUT_BUFFER_TOO_SMALL;
+ return (size_t)-1;
+ }
+
+ size_t ret = _olm_decode_base64(their_key, their_key_length, their_key);
+ if (ret == (size_t)-1) {
+ sas->last_error = OLM_INVALID_BASE64;
+ return (size_t)-1;
+ }
+
+ _olm_crypto_curve25519_shared_secret(&sas->curve25519_key, their_key, sas->secret);
+ sas->their_key_set = 1;
+ return 0;
+}
+
+int olm_sas_is_their_key_set(
+ const OlmSAS *sas
+) {
+ return sas->their_key_set;
+}
+
+size_t olm_sas_generate_bytes(
+ OlmSAS * sas,
+ const void * info, size_t info_length,
+ void * output, size_t output_length
+) {
+ if (!sas->their_key_set) {
+ sas->last_error = OLM_SAS_THEIR_KEY_NOT_SET;
+ return (size_t)-1;
+ }
+ _olm_crypto_hkdf_sha256(
+ sas->secret, sizeof(sas->secret),
+ NULL, 0,
+ (const uint8_t *) info, info_length,
+ output, output_length
+ );
+ return 0;
+}
+
+size_t olm_sas_mac_length(
+ const OlmSAS *sas
+) {
+ return _olm_encode_base64_length(SHA256_OUTPUT_LENGTH);
+}
+
+// A version of the calculate mac function that produces base64 strings that are
+// compatible with other base64 implementations.
+size_t olm_sas_calculate_mac_fixed_base64(
+ OlmSAS * sas,
+ const void * input, size_t input_length,
+ const void * info, size_t info_length,
+ void * mac, size_t mac_length
+) {
+ if (mac_length < olm_sas_mac_length(sas)) {
+ sas->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return (size_t)-1;
+ }
+ if (!sas->their_key_set) {
+ sas->last_error = OLM_SAS_THEIR_KEY_NOT_SET;
+ return (size_t)-1;
+ }
+ uint8_t key[32];
+ _olm_crypto_hkdf_sha256(
+ sas->secret, sizeof(sas->secret),
+ NULL, 0,
+ (const uint8_t *) info, info_length,
+ key, 32
+ );
+
+ uint8_t temp_mac[32];
+ _olm_crypto_hmac_sha256(key, 32, input, input_length, temp_mac);
+ _olm_encode_base64((const uint8_t *)temp_mac, SHA256_OUTPUT_LENGTH, (uint8_t *)mac);
+
+ return 0;
+}
+
+
+size_t olm_sas_calculate_mac(
+ OlmSAS * sas,
+ const void * input, size_t input_length,
+ const void * info, size_t info_length,
+ void * mac, size_t mac_length
+) {
+ if (mac_length < olm_sas_mac_length(sas)) {
+ sas->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return (size_t)-1;
+ }
+ if (!sas->their_key_set) {
+ sas->last_error = OLM_SAS_THEIR_KEY_NOT_SET;
+ return (size_t)-1;
+ }
+ uint8_t key[32];
+ _olm_crypto_hkdf_sha256(
+ sas->secret, sizeof(sas->secret),
+ NULL, 0,
+ (const uint8_t *) info, info_length,
+ key, 32
+ );
+ _olm_crypto_hmac_sha256(key, 32, input, input_length, mac);
+ _olm_encode_base64((const uint8_t *)mac, SHA256_OUTPUT_LENGTH, (uint8_t *)mac);
+ return 0;
+}
+
+// for compatibility with an old version of Riot
+size_t olm_sas_calculate_mac_long_kdf(
+ OlmSAS * sas,
+ const void * input, size_t input_length,
+ const void * info, size_t info_length,
+ void * mac, size_t mac_length
+) {
+ if (mac_length < olm_sas_mac_length(sas)) {
+ sas->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return (size_t)-1;
+ }
+ if (!sas->their_key_set) {
+ sas->last_error = OLM_SAS_THEIR_KEY_NOT_SET;
+ return (size_t)-1;
+ }
+ uint8_t key[256];
+ _olm_crypto_hkdf_sha256(
+ sas->secret, sizeof(sas->secret),
+ NULL, 0,
+ (const uint8_t *) info, info_length,
+ key, 256
+ );
+ _olm_crypto_hmac_sha256(key, 256, input, input_length, mac);
+ _olm_encode_base64((const uint8_t *)mac, SHA256_OUTPUT_LENGTH, (uint8_t *)mac);
+ return 0;
+}
diff --git a/ext/olm/src/session.cpp b/ext/olm/src/session.cpp
new file mode 100644
index 0000000..732e0c0
--- /dev/null
+++ b/ext/olm/src/session.cpp
@@ -0,0 +1,531 @@
+/* Copyright 2015, 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "olm/session.hh"
+#include "olm/cipher.h"
+#include "olm/crypto.h"
+#include "olm/account.hh"
+#include "olm/memory.hh"
+#include "olm/message.hh"
+#include "olm/pickle.hh"
+
+#include <cstring>
+#include <stdio.h>
+
+namespace {
+
+static const std::uint8_t PROTOCOL_VERSION = 0x3;
+
+static const std::uint8_t ROOT_KDF_INFO[] = "OLM_ROOT";
+static const std::uint8_t RATCHET_KDF_INFO[] = "OLM_RATCHET";
+static const std::uint8_t CIPHER_KDF_INFO[] = "OLM_KEYS";
+
+static const olm::KdfInfo OLM_KDF_INFO = {
+ ROOT_KDF_INFO, sizeof(ROOT_KDF_INFO) - 1,
+ RATCHET_KDF_INFO, sizeof(RATCHET_KDF_INFO) - 1
+};
+
+static const struct _olm_cipher_aes_sha_256 OLM_CIPHER =
+ OLM_CIPHER_INIT_AES_SHA_256(CIPHER_KDF_INFO);
+
+} // namespace
+
+olm::Session::Session(
+) : ratchet(OLM_KDF_INFO, OLM_CIPHER_BASE(&OLM_CIPHER)),
+ last_error(OlmErrorCode::OLM_SUCCESS),
+ received_message(false) {
+
+}
+
+
+std::size_t olm::Session::new_outbound_session_random_length() const {
+ return CURVE25519_RANDOM_LENGTH * 2;
+}
+
+
+std::size_t olm::Session::new_outbound_session(
+ olm::Account const & local_account,
+ _olm_curve25519_public_key const & identity_key,
+ _olm_curve25519_public_key const & one_time_key,
+ std::uint8_t const * random, std::size_t random_length
+) {
+ if (random_length < new_outbound_session_random_length()) {
+ last_error = OlmErrorCode::OLM_NOT_ENOUGH_RANDOM;
+ return std::size_t(-1);
+ }
+
+ _olm_curve25519_key_pair base_key;
+ _olm_crypto_curve25519_generate_key(random, &base_key);
+
+ _olm_curve25519_key_pair ratchet_key;
+ _olm_crypto_curve25519_generate_key(random + CURVE25519_RANDOM_LENGTH, &ratchet_key);
+
+ _olm_curve25519_key_pair const & alice_identity_key_pair = (
+ local_account.identity_keys.curve25519_key
+ );
+
+ received_message = false;
+ alice_identity_key = alice_identity_key_pair.public_key;
+ alice_base_key = base_key.public_key;
+ bob_one_time_key = one_time_key;
+
+ // Calculate the shared secret S via triple DH
+ std::uint8_t secret[3 * CURVE25519_SHARED_SECRET_LENGTH];
+ std::uint8_t * pos = secret;
+
+ _olm_crypto_curve25519_shared_secret(&alice_identity_key_pair, &one_time_key, pos);
+ pos += CURVE25519_SHARED_SECRET_LENGTH;
+ _olm_crypto_curve25519_shared_secret(&base_key, &identity_key, pos);
+ pos += CURVE25519_SHARED_SECRET_LENGTH;
+ _olm_crypto_curve25519_shared_secret(&base_key, &one_time_key, pos);
+
+ ratchet.initialise_as_alice(secret, sizeof(secret), ratchet_key);
+
+ olm::unset(base_key);
+ olm::unset(ratchet_key);
+ olm::unset(secret);
+
+ return std::size_t(0);
+}
+
+namespace {
+
+static bool check_message_fields(
+ olm::PreKeyMessageReader & reader, bool have_their_identity_key
+) {
+ bool ok = true;
+ ok = ok && (have_their_identity_key || reader.identity_key);
+ if (reader.identity_key) {
+ ok = ok && reader.identity_key_length == CURVE25519_KEY_LENGTH;
+ }
+ ok = ok && reader.message;
+ ok = ok && reader.base_key;
+ ok = ok && reader.base_key_length == CURVE25519_KEY_LENGTH;
+ ok = ok && reader.one_time_key;
+ ok = ok && reader.one_time_key_length == CURVE25519_KEY_LENGTH;
+ return ok;
+}
+
+} // namespace
+
+
+std::size_t olm::Session::new_inbound_session(
+ olm::Account & local_account,
+ _olm_curve25519_public_key const * their_identity_key,
+ std::uint8_t const * one_time_key_message, std::size_t message_length
+) {
+ olm::PreKeyMessageReader reader;
+ decode_one_time_key_message(reader, one_time_key_message, message_length);
+
+ if (!check_message_fields(reader, their_identity_key)) {
+ last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT;
+ return std::size_t(-1);
+ }
+
+ if (reader.identity_key && their_identity_key) {
+ bool same = 0 == std::memcmp(
+ their_identity_key->public_key, reader.identity_key, CURVE25519_KEY_LENGTH
+ );
+ if (!same) {
+ last_error = OlmErrorCode::OLM_BAD_MESSAGE_KEY_ID;
+ return std::size_t(-1);
+ }
+ }
+
+ olm::load_array(alice_identity_key.public_key, reader.identity_key);
+ olm::load_array(alice_base_key.public_key, reader.base_key);
+ olm::load_array(bob_one_time_key.public_key, reader.one_time_key);
+
+ olm::MessageReader message_reader;
+ decode_message(
+ message_reader, reader.message, reader.message_length,
+ ratchet.ratchet_cipher->ops->mac_length(ratchet.ratchet_cipher)
+ );
+
+ if (!message_reader.ratchet_key
+ || message_reader.ratchet_key_length != CURVE25519_KEY_LENGTH) {
+ last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT;
+ return std::size_t(-1);
+ }
+
+ _olm_curve25519_public_key ratchet_key;
+ olm::load_array(ratchet_key.public_key, message_reader.ratchet_key);
+
+ olm::OneTimeKey const * our_one_time_key = local_account.lookup_key(
+ bob_one_time_key
+ );
+
+ if (!our_one_time_key) {
+ last_error = OlmErrorCode::OLM_BAD_MESSAGE_KEY_ID;
+ return std::size_t(-1);
+ }
+
+ _olm_curve25519_key_pair const & bob_identity_key = (
+ local_account.identity_keys.curve25519_key
+ );
+ _olm_curve25519_key_pair const & bob_one_time_key = our_one_time_key->key;
+
+ // Calculate the shared secret S via triple DH
+ std::uint8_t secret[CURVE25519_SHARED_SECRET_LENGTH * 3];
+ std::uint8_t * pos = secret;
+ _olm_crypto_curve25519_shared_secret(&bob_one_time_key, &alice_identity_key, pos);
+ pos += CURVE25519_SHARED_SECRET_LENGTH;
+ _olm_crypto_curve25519_shared_secret(&bob_identity_key, &alice_base_key, pos);
+ pos += CURVE25519_SHARED_SECRET_LENGTH;
+ _olm_crypto_curve25519_shared_secret(&bob_one_time_key, &alice_base_key, pos);
+
+ ratchet.initialise_as_bob(secret, sizeof(secret), ratchet_key);
+
+ olm::unset(secret);
+
+ return std::size_t(0);
+}
+
+
+std::size_t olm::Session::session_id_length() const {
+ return SHA256_OUTPUT_LENGTH;
+}
+
+
+std::size_t olm::Session::session_id(
+ std::uint8_t * id, std::size_t id_length
+) {
+ if (id_length < session_id_length()) {
+ last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+ std::uint8_t tmp[CURVE25519_KEY_LENGTH * 3];
+ std::uint8_t * pos = tmp;
+ pos = olm::store_array(pos, alice_identity_key.public_key);
+ pos = olm::store_array(pos, alice_base_key.public_key);
+ pos = olm::store_array(pos, bob_one_time_key.public_key);
+ _olm_crypto_sha256(tmp, sizeof(tmp), id);
+ return session_id_length();
+}
+
+
+bool olm::Session::matches_inbound_session(
+ _olm_curve25519_public_key const * their_identity_key,
+ std::uint8_t const * one_time_key_message, std::size_t message_length
+) const {
+ olm::PreKeyMessageReader reader;
+ decode_one_time_key_message(reader, one_time_key_message, message_length);
+
+ if (!check_message_fields(reader, their_identity_key)) {
+ return false;
+ }
+
+ bool same = true;
+ if (reader.identity_key) {
+ same = same && 0 == std::memcmp(
+ reader.identity_key, alice_identity_key.public_key, CURVE25519_KEY_LENGTH
+ );
+ }
+ if (their_identity_key) {
+ same = same && 0 == std::memcmp(
+ their_identity_key->public_key, alice_identity_key.public_key,
+ CURVE25519_KEY_LENGTH
+ );
+ }
+ same = same && 0 == std::memcmp(
+ reader.base_key, alice_base_key.public_key, CURVE25519_KEY_LENGTH
+ );
+ same = same && 0 == std::memcmp(
+ reader.one_time_key, bob_one_time_key.public_key, CURVE25519_KEY_LENGTH
+ );
+ return same;
+}
+
+
+olm::MessageType olm::Session::encrypt_message_type() const {
+ if (received_message) {
+ return olm::MessageType::MESSAGE;
+ } else {
+ return olm::MessageType::PRE_KEY;
+ }
+}
+
+
+std::size_t olm::Session::encrypt_message_length(
+ std::size_t plaintext_length
+) const {
+ std::size_t message_length = ratchet.encrypt_output_length(
+ plaintext_length
+ );
+
+ if (received_message) {
+ return message_length;
+ }
+
+ return encode_one_time_key_message_length(
+ CURVE25519_KEY_LENGTH,
+ CURVE25519_KEY_LENGTH,
+ CURVE25519_KEY_LENGTH,
+ message_length
+ );
+}
+
+
+std::size_t olm::Session::encrypt_random_length() const {
+ return ratchet.encrypt_random_length();
+}
+
+
+std::size_t olm::Session::encrypt(
+ std::uint8_t const * plaintext, std::size_t plaintext_length,
+ std::uint8_t const * random, std::size_t random_length,
+ std::uint8_t * message, std::size_t message_length
+) {
+ if (message_length < encrypt_message_length(plaintext_length)) {
+ last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+ std::uint8_t * message_body;
+ std::size_t message_body_length = ratchet.encrypt_output_length(
+ plaintext_length
+ );
+
+ if (received_message) {
+ message_body = message;
+ } else {
+ olm::PreKeyMessageWriter writer;
+ encode_one_time_key_message(
+ writer,
+ PROTOCOL_VERSION,
+ CURVE25519_KEY_LENGTH,
+ CURVE25519_KEY_LENGTH,
+ CURVE25519_KEY_LENGTH,
+ message_body_length,
+ message
+ );
+ olm::store_array(writer.one_time_key, bob_one_time_key.public_key);
+ olm::store_array(writer.identity_key, alice_identity_key.public_key);
+ olm::store_array(writer.base_key, alice_base_key.public_key);
+ message_body = writer.message;
+ }
+
+ std::size_t result = ratchet.encrypt(
+ plaintext, plaintext_length,
+ random, random_length,
+ message_body, message_body_length
+ );
+
+ if (result == std::size_t(-1)) {
+ last_error = ratchet.last_error;
+ ratchet.last_error = OlmErrorCode::OLM_SUCCESS;
+ return result;
+ }
+
+ return result;
+}
+
+
+std::size_t olm::Session::decrypt_max_plaintext_length(
+ MessageType message_type,
+ std::uint8_t const * message, std::size_t message_length
+) {
+ std::uint8_t const * message_body;
+ std::size_t message_body_length;
+ if (message_type == olm::MessageType::MESSAGE) {
+ message_body = message;
+ message_body_length = message_length;
+ } else {
+ olm::PreKeyMessageReader reader;
+ decode_one_time_key_message(reader, message, message_length);
+ if (!reader.message) {
+ last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT;
+ return std::size_t(-1);
+ }
+ message_body = reader.message;
+ message_body_length = reader.message_length;
+ }
+
+ std::size_t result = ratchet.decrypt_max_plaintext_length(
+ message_body, message_body_length
+ );
+
+ if (result == std::size_t(-1)) {
+ last_error = ratchet.last_error;
+ ratchet.last_error = OlmErrorCode::OLM_SUCCESS;
+ }
+ return result;
+}
+
+
+std::size_t olm::Session::decrypt(
+ olm::MessageType message_type,
+ std::uint8_t const * message, std::size_t message_length,
+ std::uint8_t * plaintext, std::size_t max_plaintext_length
+) {
+ std::uint8_t const * message_body;
+ std::size_t message_body_length;
+ if (message_type == olm::MessageType::MESSAGE) {
+ message_body = message;
+ message_body_length = message_length;
+ } else {
+ olm::PreKeyMessageReader reader;
+ decode_one_time_key_message(reader, message, message_length);
+ if (!reader.message) {
+ last_error = OlmErrorCode::OLM_BAD_MESSAGE_FORMAT;
+ return std::size_t(-1);
+ }
+ message_body = reader.message;
+ message_body_length = reader.message_length;
+ }
+
+ std::size_t result = ratchet.decrypt(
+ message_body, message_body_length, plaintext, max_plaintext_length
+ );
+
+ if (result == std::size_t(-1)) {
+ last_error = ratchet.last_error;
+ ratchet.last_error = OlmErrorCode::OLM_SUCCESS;
+ return result;
+ }
+
+ received_message = true;
+ return result;
+}
+
+// make the description end with "..." instead of stopping abruptly with no
+// warning
+void elide_description(char *end) {
+ end[-3] = '.';
+ end[-2] = '.';
+ end[-1] = '.';
+ end[0] = '\0';
+}
+
+void olm::Session::describe(char *describe_buffer, size_t buflen) {
+ // how much of the buffer is remaining (this is an int rather than a size_t
+ // because it will get compared to the return value from snprintf)
+ int remaining = buflen;
+ // do nothing if we have a zero-length buffer, or if buflen > INT_MAX,
+ // resulting in an overflow
+ if (remaining <= 0) return;
+
+ describe_buffer[0] = '\0';
+ // we need at least 23 characters to get any sort of meaningful
+ // information, so bail if we don't have that. (But more importantly, we
+ // need it to be at least 4 so that elide_description doesn't go out of
+ // bounds.)
+ if (remaining < 23) return;
+
+ int size;
+
+ // check that snprintf didn't return an error or reach the end of the buffer
+#define CHECK_SIZE_AND_ADVANCE \
+ if (size > remaining) { \
+ return elide_description(describe_buffer + remaining - 1); \
+ } else if (size > 0) { \
+ describe_buffer += size; \
+ remaining -= size; \
+ } else { \
+ return; \
+ }
+
+ size = snprintf(
+ describe_buffer, remaining,
+ "sender chain index: %ld ", ratchet.sender_chain[0].chain_key.index
+ );
+ CHECK_SIZE_AND_ADVANCE;
+
+ size = snprintf(describe_buffer, remaining, "receiver chain indices:");
+ CHECK_SIZE_AND_ADVANCE;
+
+ for (size_t i = 0; i < ratchet.receiver_chains.size(); ++i) {
+ size = snprintf(
+ describe_buffer, remaining,
+ " %ld", ratchet.receiver_chains[i].chain_key.index
+ );
+ CHECK_SIZE_AND_ADVANCE;
+ }
+
+ size = snprintf(describe_buffer, remaining, " skipped message keys:");
+ CHECK_SIZE_AND_ADVANCE;
+
+ for (size_t i = 0; i < ratchet.skipped_message_keys.size(); ++i) {
+ size = snprintf(
+ describe_buffer, remaining,
+ " %ld", ratchet.skipped_message_keys[i].message_key.index
+ );
+ CHECK_SIZE_AND_ADVANCE;
+ }
+#undef CHECK_SIZE_AND_ADVANCE
+}
+
+namespace {
+// the master branch writes pickle version 1; the logging_enabled branch writes
+// 0x80000001.
+static const std::uint32_t SESSION_PICKLE_VERSION = 1;
+}
+
+std::size_t olm::pickle_length(
+ Session const & value
+) {
+ std::size_t length = 0;
+ length += olm::pickle_length(SESSION_PICKLE_VERSION);
+ length += olm::pickle_length(value.received_message);
+ length += olm::pickle_length(value.alice_identity_key);
+ length += olm::pickle_length(value.alice_base_key);
+ length += olm::pickle_length(value.bob_one_time_key);
+ length += olm::pickle_length(value.ratchet);
+ return length;
+}
+
+
+std::uint8_t * olm::pickle(
+ std::uint8_t * pos,
+ Session const & value
+) {
+ pos = olm::pickle(pos, SESSION_PICKLE_VERSION);
+ pos = olm::pickle(pos, value.received_message);
+ pos = olm::pickle(pos, value.alice_identity_key);
+ pos = olm::pickle(pos, value.alice_base_key);
+ pos = olm::pickle(pos, value.bob_one_time_key);
+ pos = olm::pickle(pos, value.ratchet);
+ return pos;
+}
+
+
+std::uint8_t const * olm::unpickle(
+ std::uint8_t const * pos, std::uint8_t const * end,
+ Session & value
+) {
+ uint32_t pickle_version;
+ pos = olm::unpickle(pos, end, pickle_version); UNPICKLE_OK(pos);
+
+ bool includes_chain_index;
+ switch (pickle_version) {
+ case 1:
+ includes_chain_index = false;
+ break;
+
+ case 0x80000001UL:
+ includes_chain_index = true;
+ break;
+
+ default:
+ value.last_error = OlmErrorCode::OLM_UNKNOWN_PICKLE_VERSION;
+ return nullptr;
+ }
+
+ pos = olm::unpickle(pos, end, value.received_message); UNPICKLE_OK(pos);
+ pos = olm::unpickle(pos, end, value.alice_identity_key); UNPICKLE_OK(pos);
+ pos = olm::unpickle(pos, end, value.alice_base_key); UNPICKLE_OK(pos);
+ pos = olm::unpickle(pos, end, value.bob_one_time_key); UNPICKLE_OK(pos);
+ pos = olm::unpickle(pos, end, value.ratchet, includes_chain_index); UNPICKLE_OK(pos);
+
+ return pos;
+}
diff --git a/ext/olm/src/utility.cpp b/ext/olm/src/utility.cpp
new file mode 100644
index 0000000..b6bb56e
--- /dev/null
+++ b/ext/olm/src/utility.cpp
@@ -0,0 +1,57 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "olm/utility.hh"
+#include "olm/crypto.h"
+
+
+olm::Utility::Utility(
+) : last_error(OlmErrorCode::OLM_SUCCESS) {
+}
+
+
+size_t olm::Utility::sha256_length() const {
+ return SHA256_OUTPUT_LENGTH;
+}
+
+
+size_t olm::Utility::sha256(
+ std::uint8_t const * input, std::size_t input_length,
+ std::uint8_t * output, std::size_t output_length
+) {
+ if (output_length < sha256_length()) {
+ last_error = OlmErrorCode::OLM_OUTPUT_BUFFER_TOO_SMALL;
+ return std::size_t(-1);
+ }
+ _olm_crypto_sha256(input, input_length, output);
+ return SHA256_OUTPUT_LENGTH;
+}
+
+
+size_t olm::Utility::ed25519_verify(
+ _olm_ed25519_public_key const & key,
+ std::uint8_t const * message, std::size_t message_length,
+ std::uint8_t const * signature, std::size_t signature_length
+) {
+ if (signature_length < ED25519_SIGNATURE_LENGTH) {
+ last_error = OlmErrorCode::OLM_BAD_MESSAGE_MAC;
+ return std::size_t(-1);
+ }
+ if (!_olm_crypto_ed25519_verify(&key, message, message_length, signature)) {
+ last_error = OlmErrorCode::OLM_BAD_MESSAGE_MAC;
+ return std::size_t(-1);
+ }
+ return std::size_t(0);
+}