openMSX
SspiUtils.cc
Go to the documentation of this file.
1#ifdef _WIN32
2
3#include "SspiUtils.hh"
4
5#include "MSXException.hh"
6
7#include "xrange.hh"
8
9#include <sddl.h>
10
11#include <bit>
12#include <cassert>
13#include <iostream>
14
15//
16// NOTE: This file MUST be kept in sync between the openmsx and openmsx-debugger projects
17//
18
19namespace openmsx::sspiutils {
20
21SspiPackageBase::SspiPackageBase(StreamWrapper& userStream, const SEC_WCHAR* securityPackage)
22 : stream(userStream)
23 , cbMaxTokenSize(GetPackageMaxTokenSize(securityPackage))
24{
25 memset(&hCreds, 0, sizeof(hCreds));
26 memset(&hContext, 0, sizeof(hContext));
27
28 if (!cbMaxTokenSize) {
29 throw MSXException("GetPackageMaxTokenSize failed");
30 }
31}
32
33SspiPackageBase::~SspiPackageBase()
34{
35 DeleteSecurityContext(&hContext);
36 FreeCredentialsHandle(&hCreds);
37}
38
39void InitTokenContextBuffer(PSecBufferDesc pSecBufferDesc, PSecBuffer pSecBuffer)
40{
41 pSecBuffer->BufferType = SECBUFFER_TOKEN;
42 pSecBuffer->cbBuffer = 0;
43 pSecBuffer->pvBuffer = nullptr;
44
45 pSecBufferDesc->ulVersion = SECBUFFER_VERSION;
46 pSecBufferDesc->cBuffers = 1;
47 pSecBufferDesc->pBuffers = pSecBuffer;
48}
49
50void ClearContextBuffers(PSecBufferDesc pSecBufferDesc)
51{
52 for (auto i : xrange(pSecBufferDesc->cBuffers)) {
53 FreeContextBuffer(pSecBufferDesc->pBuffers[i].pvBuffer);
54 pSecBufferDesc->pBuffers[i].cbBuffer = 0;
55 pSecBufferDesc->pBuffers[i].pvBuffer = nullptr;
56 }
57}
58
59void DebugPrintSecurityStatus(const char* context, SECURITY_STATUS ss)
60{
61 (void)&context;
62 (void)&ss;
63#if 0
64 switch (ss) {
65 case SEC_E_OK:
66 std::cerr << context << ": SEC_E_OK\n";
67 break;
68 case SEC_I_CONTINUE_NEEDED:
69 std::cerr << context << ": SEC_I_CONTINUE_NEEDED\n";
70 break;
71 case SEC_E_INVALID_TOKEN:
72 std::cerr << context << ": SEC_E_INVALID_TOKEN\n";
73 break;
74 case SEC_E_BUFFER_TOO_SMALL:
75 std::cerr << context << ": SEC_E_BUFFER_TOO_SMALL\n";
76 break;
77 case SEC_E_INVALID_HANDLE:
78 std::cerr << context << ": SEC_E_INVALID_HANDLE\n";
79 break;
80 case SEC_E_WRONG_PRINCIPAL:
81 std::cerr << context << ": SEC_E_WRONG_PRINCIPAL\n";
82 break;
83 default:
84 std::cerr << context << ": " << ss << '\n';
85 break;
86 }
87#endif
88}
89
90void DebugPrintSecurityBool(const char* context, BOOL ret)
91{
92 (void)&context;
93 (void)&ret;
94#if 0
95 if (ret) {
96 std::cerr << context << ": true\n";
97 } else {
98 std::cerr << context << ": false - " << GetLastError() << '\n';
99 }
100#endif
101}
102
103void DebugPrintSecurityPackageName(PCtxtHandle phContext)
104{
105 (void)&phContext;
106#if 0
107 SecPkgContext_PackageInfoA package;
108 SECURITY_STATUS ss = QueryContextAttributesA(phContext, SECPKG_ATTR_PACKAGE_INFO, &package);
109 if (ss == SEC_E_OK) {
110 std::cerr << "Using " << package.PackageInfo->Name << " package\n";
111 }
112#endif
113}
114
115void DebugPrintSecurityPrincipalName(PCtxtHandle phContext)
116{
117 (void)&phContext;
118#if 0
119 SecPkgContext_NamesA name;
120 SECURITY_STATUS ss = QueryContextAttributesA(phContext, SECPKG_ATTR_NAMES, &name);
121 if (ss == SEC_E_OK) {
122 std::cerr << "Client principal " << name.sUserName << '\n';
123 }
124#endif
125}
126
127void DebugPrintSecurityDescriptor(PSECURITY_DESCRIPTOR psd)
128{
129 (void)&psd;
130#if 0
131 char* sddl;
132 BOOL ret = ConvertSecurityDescriptorToStringSecurityDescriptorA(
133 psd,
134 SDDL_REVISION,
135 OWNER_SECURITY_INFORMATION | GROUP_SECURITY_INFORMATION |
136 DACL_SECURITY_INFORMATION | SACL_SECURITY_INFORMATION | LABEL_SECURITY_INFORMATION,
137 &sddl,
138 nullptr);
139 if (ret) {
140 std::cerr << "SecurityDescriptor: " << sddl << '\n';
141 LocalFree(sddl);
142 }
143#endif
144}
145
146// If successful, caller must free the results with LocalFree()
147// If unsuccessful, returns null
148static PTOKEN_USER GetProcessToken()
149{
150 PTOKEN_USER pToken = nullptr;
151
152 HANDLE hProcessToken;
153 BOOL ret = OpenProcessToken(GetCurrentProcess(), TOKEN_READ, &hProcessToken);
154 DebugPrintSecurityBool("OpenProcessToken", ret);
155 if (ret) {
156 DWORD cbToken;
157 ret = GetTokenInformation(hProcessToken, TokenUser, nullptr, 0, &cbToken);
158 assert(!ret && GetLastError() == ERROR_INSUFFICIENT_BUFFER && cbToken);
159
160 pToken = static_cast<TOKEN_USER*>(LocalAlloc(LMEM_ZEROINIT, cbToken));
161 if (pToken) {
162 ret = GetTokenInformation(hProcessToken, TokenUser, pToken, cbToken, &cbToken);
163 DebugPrintSecurityBool("GetTokenInformation", ret);
164 if (!ret) {
165 LocalFree(pToken);
166 pToken = nullptr;
167 }
168 }
169 CloseHandle(hProcessToken);
170 }
171 return pToken;
172}
173
174// If successful, caller must free the results with LocalFree()
175// If unsuccessful, returns null
176PSECURITY_DESCRIPTOR CreateCurrentUserSecurityDescriptor()
177{
178 PSECURITY_DESCRIPTOR psd = nullptr;
179 PTOKEN_USER pToken = GetProcessToken();
180 if (pToken) {
181 PSID pUserSid = pToken->User.Sid;
182 const DWORD cbEachAce = sizeof(ACCESS_ALLOWED_ACE) - sizeof(DWORD);
183 const DWORD cbACL = sizeof(ACL) + cbEachAce + GetLengthSid(pUserSid);
184
185 // Allocate the SD and the ACL in one allocation, so we only have one buffer to manage
186 // The SD structure ends with a pointer, so the start of the ACL will be well aligned
187 BYTE* buffer = static_cast<BYTE*>(LocalAlloc(LMEM_ZEROINIT, SECURITY_DESCRIPTOR_MIN_LENGTH + cbACL));
188 if (buffer) {
189 psd = static_cast<PSECURITY_DESCRIPTOR>(buffer);
190 PACL pacl = reinterpret_cast<PACL>(buffer + SECURITY_DESCRIPTOR_MIN_LENGTH);
191 PACCESS_ALLOWED_ACE pUserAce;
192 if (InitializeSecurityDescriptor(psd, SECURITY_DESCRIPTOR_REVISION) &&
193 InitializeAcl(pacl, cbACL, ACL_REVISION) &&
194 AddAccessAllowedAce(pacl, ACL_REVISION, ACCESS_ALL, pUserSid) &&
195 SetSecurityDescriptorDacl(psd, TRUE, pacl, FALSE) &&
196 // Need to set the Group and Owner on the SD in order to use it with AccessCheck()
197 GetAce(pacl, 0, std::bit_cast<void**>(&pUserAce)) &&
198 SetSecurityDescriptorGroup(psd, &pUserAce->SidStart, FALSE) &&
199 SetSecurityDescriptorOwner(psd, &pUserAce->SidStart, FALSE)) {
200 buffer = nullptr;
201 } else {
202 psd = nullptr;
203 }
204 LocalFree(buffer);
205 }
206 LocalFree(pToken);
207 }
208
209 if (psd) {
210 assert(IsValidSecurityDescriptor(psd));
211 DebugPrintSecurityDescriptor(psd);
212 }
213 return psd;
214}
215
216unsigned long GetPackageMaxTokenSize(const SEC_WCHAR* package)
217{
218 PSecPkgInfoW pkgInfo;
219 SECURITY_STATUS ss = QuerySecurityPackageInfoW(const_cast<SEC_WCHAR*>(package), &pkgInfo);
220 DebugPrintSecurityStatus("QuerySecurityPackageInfoW", ss);
221 if (ss != SEC_E_OK) return 0;
222
223 unsigned long cbMaxToken = pkgInfo->cbMaxToken;
224 FreeContextBuffer(pkgInfo);
225 return cbMaxToken;
226}
227
228static bool Send(StreamWrapper& stream, void* buffer, uint32_t cb)
229{
230 uint32_t sent = 0;
231 while (sent < cb) {
232 uint32_t ret = stream.Write(static_cast<char*>(buffer) + sent, cb - sent);
233 if (ret == STREAM_ERROR) return false;
234 sent += ret;
235 }
236 return true;
237}
238
239bool SendChunk(StreamWrapper& stream, void* buffer, uint32_t cb)
240{
241 uint32_t nl = htonl(cb);
242 if (!Send(stream, &nl, sizeof(nl))) {
243 return false;
244 }
245 return Send(stream, buffer, cb);
246}
247
248static bool Recv(StreamWrapper& stream, void* buffer, uint32_t cb)
249{
250 uint32_t recvd = 0;
251 while (recvd < cb) {
252 uint32_t ret = stream.Read(static_cast<char*>(buffer) + recvd, cb - recvd);
253 if (ret == STREAM_ERROR) return false;
254 recvd += ret;
255 }
256 return true;
257}
258
259static bool RecvChunkSize(StreamWrapper& stream, uint32_t* pcb)
260{
261 uint32_t cb;
262 bool ret = Recv(stream, &cb, sizeof(cb));
263 if (ret) {
264 *pcb = ntohl(cb);
265 }
266 return ret;
267}
268
269bool RecvChunk(StreamWrapper& stream, std::vector<char>& buffer, uint32_t cbMaxSize)
270{
271 uint32_t cb;
272 if (!RecvChunkSize(stream, &cb) || cb > cbMaxSize) {
273 return false;
274 }
275 buffer.resize(cb);
276 if (!Recv(stream, &buffer[0], cb)) {
277 return false;
278 }
279 return true;
280}
281
282} // namespace openmsx::sspiutils
283
284#endif
std::optional< Context > context
Definition GLContext.cc:10
constexpr auto xrange(T e)
Definition xrange.hh:132