1#![allow(non_camel_case_types, non_snake_case)]
2
3use libc::c_void;
4
/// Windows (MSVC) support for seeding an OpenSSL `SSL_CTX` with certificates
/// from the current user's "ROOT" system certificate store.
///
/// OpenSSL is not linked here; the handful of functions needed are resolved at
/// runtime from modules that are already loaded into the process.
#[cfg(target_env = "msvc")]
mod win {
    use schannel::cert_context::ValidUses;
    use schannel::cert_store::CertStore;
    use std::ffi::*;
    use std::mem;
    use std::ptr;
    use windows_sys::Win32::Security::Cryptography::*;
    use windows_sys::Win32::System::LibraryLoader::*;

    /// Looks up the exported `symbol` in the already-loaded module `module`.
    ///
    /// Returns `None` when the export cannot be resolved. `GetModuleHandleW`
    /// only finds modules that are already loaded; nothing is loaded here.
    fn lookup(module: &str, symbol: &str) -> Option<*const c_void> {
        // Both Win32 calls require NUL-terminated strings, which Rust `&str`s
        // (including the `stringify!` output used by the callers) are not.
        let mut module_w: Vec<u16> = module.encode_utf16().collect();
        module_w.push(0);
        let mut symbol_c: Vec<u8> = symbol.as_bytes().to_vec();
        symbol_c.push(0);
        unsafe {
            let handle = GetModuleHandleW(module_w.as_mut_ptr());
            GetProcAddress(handle, symbol_c.as_ptr()).map(|f| f as *const c_void)
        }
    }

    pub enum X509_STORE {}
    pub enum X509 {}
    pub enum SSL_CTX {}

    type d2i_X509_fn = unsafe extern "C" fn(
        a: *mut *mut X509,
        pp: *mut *const c_uchar,
        length: c_long,
    ) -> *mut X509;
    type X509_free_fn = unsafe extern "C" fn(x: *mut X509);
    type X509_STORE_add_cert_fn =
        unsafe extern "C" fn(store: *mut X509_STORE, x: *mut X509) -> c_int;
    type SSL_CTX_get_cert_store_fn = unsafe extern "C" fn(ctx: *const SSL_CTX) -> *mut X509_STORE;

    /// Runtime-resolved OpenSSL entry points used by `add_certs_to_context`.
    struct OpenSSL {
        d2i_X509: d2i_X509_fn,
        X509_free: X509_free_fn,
        X509_STORE_add_cert: X509_STORE_add_cert_fn,
        SSL_CTX_get_cert_store: SSL_CTX_get_cert_store_fn,
    }

    /// Resolves all required OpenSSL functions. The `crypto_module` and
    /// `ssl_module` arguments name the libcrypto and libssl DLLs.
    ///
    /// Returns `None` if any single symbol is missing.
    ///
    /// # Safety
    ///
    /// The raw export addresses are transmuted to the function-pointer types
    /// declared above without any check. They are only sound to call if the
    /// loaded DLLs really export functions with those signatures.
    unsafe fn lookup_functions(crypto_module: &str, ssl_module: &str) -> Option<OpenSSL> {
        // Binds a local named after each symbol, bailing out on the first miss.
        macro_rules! get {
            ($(let $sym:ident in $module:expr;)*) => ($(
                let $sym = match lookup($module, stringify!($sym)) {
                    Some(p) => p,
                    None => return None,
                };
            )*)
        }
        get! {
            let d2i_X509 in crypto_module;
            let X509_free in crypto_module;
            let X509_STORE_add_cert in crypto_module;
            let SSL_CTX_get_cert_store in ssl_module;
        }
        Some(OpenSSL {
            d2i_X509: mem::transmute(d2i_X509),
            X509_free: mem::transmute(X509_free),
            X509_STORE_add_cert: mem::transmute(X509_STORE_add_cert),
            SSL_CTX_get_cert_store: mem::transmute(SSL_CTX_get_cert_store),
        })
    }

    /// Adds the current user's "ROOT" certificates to the certificate store of
    /// `ssl_ctx`. A certificate is added if it is valid for all uses or lists
    /// the "Server Authentication" extended key usage.
    ///
    /// This is best-effort. It silently returns without changes in these cases:
    /// - the runtime SSL version is not OpenSSL 1.1.0 or 1.0.2;
    /// - the OpenSSL symbols cannot be resolved;
    /// - the context has no certificate store;
    /// - the Windows store cannot be opened.
    ///
    /// Certificates whose usages cannot be read, or which fail to parse, are
    /// skipped.
    ///
    /// # Safety
    ///
    /// `ssl_ctx` must point to a live OpenSSL `SSL_CTX` created by the same
    /// OpenSSL build whose DLLs are looked up here.
    pub unsafe fn add_certs_to_context(ssl_ctx: *mut c_void) {
        // Each supported OpenSSL line ships under different DLL names.
        let openssl = match crate::version::Version::get().ssl_version() {
            Some(ssl_ver) if ssl_ver.starts_with("OpenSSL/1.1.0") => {
                lookup_functions("libcrypto", "libssl")
            }
            Some(ssl_ver) if ssl_ver.starts_with("OpenSSL/1.0.2") => {
                lookup_functions("libeay32", "ssleay32")
            }
            _ => return,
        };
        let openssl = match openssl {
            Some(s) => s,
            None => return,
        };

        let openssl_store = (openssl.SSL_CTX_get_cert_store)(ssl_ctx as *const SSL_CTX);
        if openssl_store.is_null() {
            // Nothing to add to; passing NULL on to X509_STORE_add_cert is unsound.
            return;
        }
        let store = match CertStore::open_current_user("ROOT") {
            Ok(s) => s,
            Err(_) => return,
        };

        // Loop-invariant: the "Server Authentication" EKU OID as a Rust string,
        // in the same form that `ValidUses::Oids` reports.
        let server_auth_oid = CStr::from_ptr(szOID_PKIX_KP_SERVER_AUTH as *const _)
            .to_string_lossy()
            .into_owned();

        for cert in store.certs() {
            let valid_uses = match cert.valid_uses() {
                Ok(v) => v,
                Err(_) => continue,
            };

            // Only import certificates usable for TLS server authentication.
            match valid_uses {
                ValidUses::All => {}
                ValidUses::Oids(ref oids) => {
                    if !oids.contains(&server_auth_oid) {
                        continue;
                    }
                }
            }

            let der = cert.to_der();
            // A DER certificate is far smaller than c_long::MAX, so this cast
            // does not truncate in practice.
            let x509 = (openssl.d2i_X509)(ptr::null_mut(), &mut der.as_ptr(), der.len() as c_long);
            if !x509.is_null() {
                // The return value is ignored on purpose: a rejected certificate
                // (e.g. a duplicate) must not stop the import.
                (openssl.X509_STORE_add_cert)(openssl_store, x509);
                // X509_STORE_add_cert takes its own reference, so release ours.
                (openssl.X509_free)(x509);
            }
        }
    }
}
118
/// Loads the Windows system root certificates into the OpenSSL `SSL_CTX`
/// that `ssl_ctx` points to. This variant is compiled only for MSVC targets.
///
/// NOTE(review): this is a safe `fn` that forwards a raw pointer into an
/// `unsafe fn`. Callers must still hand in a valid `SSL_CTX` pointer.
#[cfg(target_env = "msvc")]
pub fn add_certs_to_context(ssl_ctx: *mut c_void) {
    // SAFETY: relies on the caller passing a pointer to a live `SSL_CTX`;
    // `.cast()` only re-types the pointer to the inner module's `c_void`.
    unsafe { win::add_certs_to_context(ssl_ctx.cast()) }
}
125
/// Non-MSVC fallback: the Windows certificate-store import is unavailable, so
/// this accepts the context pointer and does nothing with it.
#[cfg(not(target_env = "msvc"))]
pub fn add_certs_to_context(_ssl_ctx: *mut c_void) {
    // Intentionally empty: the pointer is never read or dereferenced.
}
128