GNU Linux-libre 4.19.264-gnu1
[releases.git] / tools / testing / selftests / x86 / protection_keys.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Tests x86 Memory Protection Keys (see Documentation/x86/protection-keys.txt)
4  *
5  * There are examples in here of:
6  *  * how to set protection keys on memory
7  *  * how to set/clear bits in PKRU (the rights register)
8  *  * how to handle SEGV_PKRU signals and extract pkey-relevant
9  *    information from the siginfo
10  *
11  * Things to add:
12  *      make sure KSM and KSM COW breaking works
13  *      prefault pages in at malloc, or not
14  *      protect MPX bounds tables with protection keys?
15  *      make sure VMA splitting/merging is working correctly
16  *      OOMs can destroy mm->mmap (see exit_mmap()), so make sure it is immune to pkeys
17  *      look for pkey "leaks" where it is still set on a VMA but "freed" back to the kernel
18  *      do a plain mprotect() to a mprotect_pkey() area and make sure the pkey sticks
19  *
20  * Compile like this:
21  *      gcc      -o protection_keys    -O2 -g -std=gnu99 -pthread -Wall protection_keys.c -lrt -ldl -lm
22  *      gcc -m32 -o protection_keys_32 -O2 -g -std=gnu99 -pthread -Wall protection_keys.c -lrt -ldl -lm
23  */
24 #define _GNU_SOURCE
25 #include <errno.h>
26 #include <linux/futex.h>
27 #include <time.h>
28 #include <sys/time.h>
29 #include <sys/syscall.h>
30 #include <string.h>
31 #include <stdio.h>
32 #include <stdint.h>
33 #include <stdbool.h>
34 #include <signal.h>
35 #include <assert.h>
36 #include <stdlib.h>
37 #include <ucontext.h>
38 #include <sys/mman.h>
39 #include <sys/types.h>
40 #include <sys/wait.h>
41 #include <sys/stat.h>
42 #include <fcntl.h>
43 #include <unistd.h>
44 #include <sys/ptrace.h>
45 #include <setjmp.h>
46
47 #include "pkey-helpers.h"
48
49 int iteration_nr = 1;
50 int test_nr;
51
52 unsigned int shadow_pkru;
53
54 #define HPAGE_SIZE      (1UL<<21)
55 #define ARRAY_SIZE(x) (sizeof(x) / sizeof(*(x)))
56 #define ALIGN_UP(x, align_to)   (((x) + ((align_to)-1)) & ~((align_to)-1))
57 #define ALIGN_DOWN(x, align_to) ((x) & ~((align_to)-1))
58 #define ALIGN_PTR_UP(p, ptr_align_to)   ((typeof(p))ALIGN_UP((unsigned long)(p),        ptr_align_to))
59 #define ALIGN_PTR_DOWN(p, ptr_align_to) ((typeof(p))ALIGN_DOWN((unsigned long)(p),      ptr_align_to))
60 #define __stringify_1(x...)     #x
61 #define __stringify(x...)       __stringify_1(x)
62
63 #define PTR_ERR_ENOTSUP ((void *)-ENOTSUP)
64
65 int dprint_in_signal;
66 char dprint_in_signal_buffer[DPRINT_IN_SIGNAL_BUF_SIZE];
67
68 extern void abort_hooks(void);
69 #define pkey_assert(condition) do {             \
70         if (!(condition)) {                     \
71                 dprintf0("assert() at %s::%d test_nr: %d iteration: %d\n", \
72                                 __FILE__, __LINE__,     \
73                                 test_nr, iteration_nr); \
74                 dprintf0("errno at assert: %d", errno); \
75                 abort_hooks();                  \
76                 exit(__LINE__);                 \
77         }                                       \
78 } while (0)
79
80 void cat_into_file(char *str, char *file)
81 {
82         int fd = open(file, O_RDWR);
83         int ret;
84
85         dprintf2("%s(): writing '%s' to '%s'\n", __func__, str, file);
86         /*
87          * these need to be raw because they are called under
88          * pkey_assert()
89          */
90         if (fd < 0) {
91                 fprintf(stderr, "error opening '%s'\n", str);
92                 perror("error: ");
93                 exit(__LINE__);
94         }
95
96         ret = write(fd, str, strlen(str));
97         if (ret != strlen(str)) {
98                 perror("write to file failed");
99                 fprintf(stderr, "filename: '%s' str: '%s'\n", file, str);
100                 exit(__LINE__);
101         }
102         close(fd);
103 }
104
105 #if CONTROL_TRACING > 0
106 static int warned_tracing;
107 int tracing_root_ok(void)
108 {
109         if (geteuid() != 0) {
110                 if (!warned_tracing)
111                         fprintf(stderr, "WARNING: not run as root, "
112                                         "can not do tracing control\n");
113                 warned_tracing = 1;
114                 return 0;
115         }
116         return 1;
117 }
118 #endif
119
120 void tracing_on(void)
121 {
122 #if CONTROL_TRACING > 0
123 #define TRACEDIR "/sys/kernel/debug/tracing"
124         char pidstr[32];
125
126         if (!tracing_root_ok())
127                 return;
128
129         sprintf(pidstr, "%d", getpid());
130         cat_into_file("0", TRACEDIR "/tracing_on");
131         cat_into_file("\n", TRACEDIR "/trace");
132         if (1) {
133                 cat_into_file("function_graph", TRACEDIR "/current_tracer");
134                 cat_into_file("1", TRACEDIR "/options/funcgraph-proc");
135         } else {
136                 cat_into_file("nop", TRACEDIR "/current_tracer");
137         }
138         cat_into_file(pidstr, TRACEDIR "/set_ftrace_pid");
139         cat_into_file("1", TRACEDIR "/tracing_on");
140         dprintf1("enabled tracing\n");
141 #endif
142 }
143
144 void tracing_off(void)
145 {
146 #if CONTROL_TRACING > 0
147         if (!tracing_root_ok())
148                 return;
149         cat_into_file("0", "/sys/kernel/debug/tracing/tracing_on");
150 #endif
151 }
152
153 void abort_hooks(void)
154 {
155         fprintf(stderr, "running %s()...\n", __func__);
156         tracing_off();
157 #ifdef SLEEP_ON_ABORT
158         sleep(SLEEP_ON_ABORT);
159 #endif
160 }
161
162 static inline void __page_o_noops(void)
163 {
164         /* 8-bytes of instruction * 512 bytes = 1 page */
165         asm(".rept 512 ; nopl 0x7eeeeeee(%eax) ; .endr");
166 }
167
168 /*
169  * This attempts to have roughly a page of instructions followed by a few
170  * instructions that do a write, and another page of instructions.  That
171  * way, we are pretty sure that the write is in the second page of
172  * instructions and has at least a page of padding behind it.
173  *
174  * *That* lets us be sure to madvise() away the write instruction, which
175  * will then fault, which makes sure that the fault code handles
176  * execute-only memory properly.
177  */
178 __attribute__((__aligned__(PAGE_SIZE)))
179 void lots_o_noops_around_write(int *write_to_me)
180 {
181         dprintf3("running %s()\n", __func__);
182         __page_o_noops();
183         /* Assume this happens in the second page of instructions: */
184         *write_to_me = __LINE__;
185         /* pad out by another page: */
186         __page_o_noops();
187         dprintf3("%s() done\n", __func__);
188 }
189
190 /* Define some kernel-like types */
191 #define  u8 uint8_t
192 #define u16 uint16_t
193 #define u32 uint32_t
194 #define u64 uint64_t
195
196 #ifdef __i386__
197
198 #ifndef SYS_mprotect_key
199 # define SYS_mprotect_key       380
200 #endif
201
202 #ifndef SYS_pkey_alloc
203 # define SYS_pkey_alloc         381
204 # define SYS_pkey_free          382
205 #endif
206
207 #define REG_IP_IDX              REG_EIP
208 #define si_pkey_offset          0x14
209
210 #else
211
212 #ifndef SYS_mprotect_key
213 # define SYS_mprotect_key       329
214 #endif
215
216 #ifndef SYS_pkey_alloc
217 # define SYS_pkey_alloc         330
218 # define SYS_pkey_free          331
219 #endif
220
221 #define REG_IP_IDX              REG_RIP
222 #define si_pkey_offset          0x20
223
224 #endif
225
226 void dump_mem(void *dumpme, int len_bytes)
227 {
228         char *c = (void *)dumpme;
229         int i;
230
231         for (i = 0; i < len_bytes; i += sizeof(u64)) {
232                 u64 *ptr = (u64 *)(c + i);
233                 dprintf1("dump[%03d][@%p]: %016jx\n", i, ptr, *ptr);
234         }
235 }
236
237 /* Failed address bound checks: */
238 #ifndef SEGV_BNDERR
239 # define SEGV_BNDERR            3
240 #endif
241
242 #ifndef SEGV_PKUERR
243 # define SEGV_PKUERR            4
244 #endif
245
246 static char *si_code_str(int si_code)
247 {
248         if (si_code == SEGV_MAPERR)
249                 return "SEGV_MAPERR";
250         if (si_code == SEGV_ACCERR)
251                 return "SEGV_ACCERR";
252         if (si_code == SEGV_BNDERR)
253                 return "SEGV_BNDERR";
254         if (si_code == SEGV_PKUERR)
255                 return "SEGV_PKUERR";
256         return "UNKNOWN";
257 }
258
259 int pkru_faults;
260 int last_si_pkey = -1;
261 void signal_handler(int signum, siginfo_t *si, void *vucontext)
262 {
263         ucontext_t *uctxt = vucontext;
264         int trapno;
265         unsigned long ip;
266         char *fpregs;
267         u32 *pkru_ptr;
268         u64 siginfo_pkey;
269         u32 *si_pkey_ptr;
270         int pkru_offset;
271         fpregset_t fpregset;
272
273         dprint_in_signal = 1;
274         dprintf1(">>>>===============SIGSEGV============================\n");
275         dprintf1("%s()::%d, pkru: 0x%x shadow: %x\n", __func__, __LINE__,
276                         __rdpkru(), shadow_pkru);
277
278         trapno = uctxt->uc_mcontext.gregs[REG_TRAPNO];
279         ip = uctxt->uc_mcontext.gregs[REG_IP_IDX];
280         fpregset = uctxt->uc_mcontext.fpregs;
281         fpregs = (void *)fpregset;
282
283         dprintf2("%s() trapno: %d ip: 0x%lx info->si_code: %s/%d\n", __func__,
284                         trapno, ip, si_code_str(si->si_code), si->si_code);
285 #ifdef __i386__
286         /*
287          * 32-bit has some extra padding so that userspace can tell whether
288          * the XSTATE header is present in addition to the "legacy" FPU
289          * state.  We just assume that it is here.
290          */
291         fpregs += 0x70;
292 #endif
293         pkru_offset = pkru_xstate_offset();
294         pkru_ptr = (void *)(&fpregs[pkru_offset]);
295
296         dprintf1("siginfo: %p\n", si);
297         dprintf1(" fpregs: %p\n", fpregs);
298         /*
299          * If we got a PKRU fault, we *HAVE* to have at least one bit set in
300          * here.
301          */
302         dprintf1("pkru_xstate_offset: %d\n", pkru_xstate_offset());
303         if (DEBUG_LEVEL > 4)
304                 dump_mem(pkru_ptr - 128, 256);
305         pkey_assert(*pkru_ptr);
306
307         if ((si->si_code == SEGV_MAPERR) ||
308             (si->si_code == SEGV_ACCERR) ||
309             (si->si_code == SEGV_BNDERR)) {
310                 printf("non-PK si_code, exiting...\n");
311                 exit(4);
312         }
313
314         si_pkey_ptr = (u32 *)(((u8 *)si) + si_pkey_offset);
315         dprintf1("si_pkey_ptr: %p\n", si_pkey_ptr);
316         dump_mem((u8 *)si_pkey_ptr - 8, 24);
317         siginfo_pkey = *si_pkey_ptr;
318         pkey_assert(siginfo_pkey < NR_PKEYS);
319         last_si_pkey = siginfo_pkey;
320
321         dprintf1("signal pkru from xsave: %08x\n", *pkru_ptr);
322         /* need __rdpkru() version so we do not do shadow_pkru checking */
323         dprintf1("signal pkru from  pkru: %08x\n", __rdpkru());
324         dprintf1("pkey from siginfo: %jx\n", siginfo_pkey);
325         *(u64 *)pkru_ptr = 0x00000000;
326         dprintf1("WARNING: set PRKU=0 to allow faulting instruction to continue\n");
327         pkru_faults++;
328         dprintf1("<<<<==================================================\n");
329         dprint_in_signal = 0;
330 }
331
332 int wait_all_children(void)
333 {
334         int status;
335         return waitpid(-1, &status, 0);
336 }
337
338 void sig_chld(int x)
339 {
340         dprint_in_signal = 1;
341         dprintf2("[%d] SIGCHLD: %d\n", getpid(), x);
342         dprint_in_signal = 0;
343 }
344
345 void setup_sigsegv_handler(void)
346 {
347         int r, rs;
348         struct sigaction newact;
349         struct sigaction oldact;
350
351         /* #PF is mapped to sigsegv */
352         int signum  = SIGSEGV;
353
354         newact.sa_handler = 0;
355         newact.sa_sigaction = signal_handler;
356
357         /*sigset_t - signals to block while in the handler */
358         /* get the old signal mask. */
359         rs = sigprocmask(SIG_SETMASK, 0, &newact.sa_mask);
360         pkey_assert(rs == 0);
361
362         /* call sa_sigaction, not sa_handler*/
363         newact.sa_flags = SA_SIGINFO;
364
365         newact.sa_restorer = 0;  /* void(*)(), obsolete */
366         r = sigaction(signum, &newact, &oldact);
367         r = sigaction(SIGALRM, &newact, &oldact);
368         pkey_assert(r == 0);
369 }
370
371 void setup_handlers(void)
372 {
373         signal(SIGCHLD, &sig_chld);
374         setup_sigsegv_handler();
375 }
376
377 pid_t fork_lazy_child(void)
378 {
379         pid_t forkret;
380
381         forkret = fork();
382         pkey_assert(forkret >= 0);
383         dprintf3("[%d] fork() ret: %d\n", getpid(), forkret);
384
385         if (!forkret) {
386                 /* in the child */
387                 while (1) {
388                         dprintf1("child sleeping...\n");
389                         sleep(30);
390                 }
391         }
392         return forkret;
393 }
394
395 #ifndef PKEY_DISABLE_ACCESS
396 # define PKEY_DISABLE_ACCESS    0x1
397 #endif
398
399 #ifndef PKEY_DISABLE_WRITE
400 # define PKEY_DISABLE_WRITE     0x2
401 #endif
402
403 static u32 hw_pkey_get(int pkey, unsigned long flags)
404 {
405         u32 mask = (PKEY_DISABLE_ACCESS|PKEY_DISABLE_WRITE);
406         u32 pkru = __rdpkru();
407         u32 shifted_pkru;
408         u32 masked_pkru;
409
410         dprintf1("%s(pkey=%d, flags=%lx) = %x / %d\n",
411                         __func__, pkey, flags, 0, 0);
412         dprintf2("%s() raw pkru: %x\n", __func__, pkru);
413
414         shifted_pkru = (pkru >> (pkey * PKRU_BITS_PER_PKEY));
415         dprintf2("%s() shifted_pkru: %x\n", __func__, shifted_pkru);
416         masked_pkru = shifted_pkru & mask;
417         dprintf2("%s() masked  pkru: %x\n", __func__, masked_pkru);
418         /*
419          * shift down the relevant bits to the lowest two, then
420          * mask off all the other high bits.
421          */
422         return masked_pkru;
423 }
424
425 static int hw_pkey_set(int pkey, unsigned long rights, unsigned long flags)
426 {
427         u32 mask = (PKEY_DISABLE_ACCESS|PKEY_DISABLE_WRITE);
428         u32 old_pkru = __rdpkru();
429         u32 new_pkru;
430
431         /* make sure that 'rights' only contains the bits we expect: */
432         assert(!(rights & ~mask));
433
434         /* copy old pkru */
435         new_pkru = old_pkru;
436         /* mask out bits from pkey in old value: */
437         new_pkru &= ~(mask << (pkey * PKRU_BITS_PER_PKEY));
438         /* OR in new bits for pkey: */
439         new_pkru |= (rights << (pkey * PKRU_BITS_PER_PKEY));
440
441         __wrpkru(new_pkru);
442
443         dprintf3("%s(pkey=%d, rights=%lx, flags=%lx) = %x pkru now: %x old_pkru: %x\n",
444                         __func__, pkey, rights, flags, 0, __rdpkru(), old_pkru);
445         return 0;
446 }
447
448 void pkey_disable_set(int pkey, int flags)
449 {
450         unsigned long syscall_flags = 0;
451         int ret;
452         int pkey_rights;
453         u32 orig_pkru = rdpkru();
454
455         dprintf1("START->%s(%d, 0x%x)\n", __func__,
456                 pkey, flags);
457         pkey_assert(flags & (PKEY_DISABLE_ACCESS | PKEY_DISABLE_WRITE));
458
459         pkey_rights = hw_pkey_get(pkey, syscall_flags);
460
461         dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
462                         pkey, pkey, pkey_rights);
463         pkey_assert(pkey_rights >= 0);
464
465         pkey_rights |= flags;
466
467         ret = hw_pkey_set(pkey, pkey_rights, syscall_flags);
468         assert(!ret);
469         /*pkru and flags have the same format */
470         shadow_pkru |= flags << (pkey * 2);
471         dprintf1("%s(%d) shadow: 0x%x\n", __func__, pkey, shadow_pkru);
472
473         pkey_assert(ret >= 0);
474
475         pkey_rights = hw_pkey_get(pkey, syscall_flags);
476         dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
477                         pkey, pkey, pkey_rights);
478
479         dprintf1("%s(%d) pkru: 0x%x\n", __func__, pkey, rdpkru());
480         if (flags)
481                 pkey_assert(rdpkru() > orig_pkru);
482         dprintf1("END<---%s(%d, 0x%x)\n", __func__,
483                 pkey, flags);
484 }
485
486 void pkey_disable_clear(int pkey, int flags)
487 {
488         unsigned long syscall_flags = 0;
489         int ret;
490         int pkey_rights = hw_pkey_get(pkey, syscall_flags);
491         u32 orig_pkru = rdpkru();
492
493         pkey_assert(flags & (PKEY_DISABLE_ACCESS | PKEY_DISABLE_WRITE));
494
495         dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
496                         pkey, pkey, pkey_rights);
497         pkey_assert(pkey_rights >= 0);
498
499         pkey_rights |= flags;
500
501         ret = hw_pkey_set(pkey, pkey_rights, 0);
502         /* pkru and flags have the same format */
503         shadow_pkru &= ~(flags << (pkey * 2));
504         pkey_assert(ret >= 0);
505
506         pkey_rights = hw_pkey_get(pkey, syscall_flags);
507         dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
508                         pkey, pkey, pkey_rights);
509
510         dprintf1("%s(%d) pkru: 0x%x\n", __func__, pkey, rdpkru());
511         if (flags)
512                 assert(rdpkru() > orig_pkru);
513 }
514
515 void pkey_write_allow(int pkey)
516 {
517         pkey_disable_clear(pkey, PKEY_DISABLE_WRITE);
518 }
519 void pkey_write_deny(int pkey)
520 {
521         pkey_disable_set(pkey, PKEY_DISABLE_WRITE);
522 }
523 void pkey_access_allow(int pkey)
524 {
525         pkey_disable_clear(pkey, PKEY_DISABLE_ACCESS);
526 }
527 void pkey_access_deny(int pkey)
528 {
529         pkey_disable_set(pkey, PKEY_DISABLE_ACCESS);
530 }
531
532 int sys_mprotect_pkey(void *ptr, size_t size, unsigned long orig_prot,
533                 unsigned long pkey)
534 {
535         int sret;
536
537         dprintf2("%s(0x%p, %zx, prot=%lx, pkey=%lx)\n", __func__,
538                         ptr, size, orig_prot, pkey);
539
540         errno = 0;
541         sret = syscall(SYS_mprotect_key, ptr, size, orig_prot, pkey);
542         if (errno) {
543                 dprintf2("SYS_mprotect_key sret: %d\n", sret);
544                 dprintf2("SYS_mprotect_key prot: 0x%lx\n", orig_prot);
545                 dprintf2("SYS_mprotect_key failed, errno: %d\n", errno);
546                 if (DEBUG_LEVEL >= 2)
547                         perror("SYS_mprotect_pkey");
548         }
549         return sret;
550 }
551
552 int sys_pkey_alloc(unsigned long flags, unsigned long init_val)
553 {
554         int ret = syscall(SYS_pkey_alloc, flags, init_val);
555         dprintf1("%s(flags=%lx, init_val=%lx) syscall ret: %d errno: %d\n",
556                         __func__, flags, init_val, ret, errno);
557         return ret;
558 }
559
560 int alloc_pkey(void)
561 {
562         int ret;
563         unsigned long init_val = 0x0;
564
565         dprintf1("alloc_pkey()::%d, pkru: 0x%x shadow: %x\n",
566                         __LINE__, __rdpkru(), shadow_pkru);
567         ret = sys_pkey_alloc(0, init_val);
568         /*
569          * pkey_alloc() sets PKRU, so we need to reflect it in
570          * shadow_pkru:
571          */
572         dprintf4("alloc_pkey()::%d, ret: %d pkru: 0x%x shadow: 0x%x\n",
573                         __LINE__, ret, __rdpkru(), shadow_pkru);
574         if (ret) {
575                 /* clear both the bits: */
576                 shadow_pkru &= ~(0x3      << (ret * 2));
577                 dprintf4("alloc_pkey()::%d, ret: %d pkru: 0x%x shadow: 0x%x\n",
578                                 __LINE__, ret, __rdpkru(), shadow_pkru);
579                 /*
580                  * move the new state in from init_val
581                  * (remember, we cheated and init_val == pkru format)
582                  */
583                 shadow_pkru |=  (init_val << (ret * 2));
584         }
585         dprintf4("alloc_pkey()::%d, ret: %d pkru: 0x%x shadow: 0x%x\n",
586                         __LINE__, ret, __rdpkru(), shadow_pkru);
587         dprintf1("alloc_pkey()::%d errno: %d\n", __LINE__, errno);
588         /* for shadow checking: */
589         rdpkru();
590         dprintf4("alloc_pkey()::%d, ret: %d pkru: 0x%x shadow: 0x%x\n",
591                         __LINE__, ret, __rdpkru(), shadow_pkru);
592         return ret;
593 }
594
595 int sys_pkey_free(unsigned long pkey)
596 {
597         int ret = syscall(SYS_pkey_free, pkey);
598         dprintf1("%s(pkey=%ld) syscall ret: %d\n", __func__, pkey, ret);
599         return ret;
600 }
601
602 /*
603  * I had a bug where pkey bits could be set by mprotect() but
604  * not cleared.  This ensures we get lots of random bit sets
605  * and clears on the vma and pte pkey bits.
606  */
607 int alloc_random_pkey(void)
608 {
609         int max_nr_pkey_allocs;
610         int ret;
611         int i;
612         int alloced_pkeys[NR_PKEYS];
613         int nr_alloced = 0;
614         int random_index;
615         memset(alloced_pkeys, 0, sizeof(alloced_pkeys));
616
617         /* allocate every possible key and make a note of which ones we got */
618         max_nr_pkey_allocs = NR_PKEYS;
619         for (i = 0; i < max_nr_pkey_allocs; i++) {
620                 int new_pkey = alloc_pkey();
621                 if (new_pkey < 0)
622                         break;
623                 alloced_pkeys[nr_alloced++] = new_pkey;
624         }
625
626         pkey_assert(nr_alloced > 0);
627         /* select a random one out of the allocated ones */
628         random_index = rand() % nr_alloced;
629         ret = alloced_pkeys[random_index];
630         /* now zero it out so we don't free it next */
631         alloced_pkeys[random_index] = 0;
632
633         /* go through the allocated ones that we did not want and free them */
634         for (i = 0; i < nr_alloced; i++) {
635                 int free_ret;
636                 if (!alloced_pkeys[i])
637                         continue;
638                 free_ret = sys_pkey_free(alloced_pkeys[i]);
639                 pkey_assert(!free_ret);
640         }
641         dprintf1("%s()::%d, ret: %d pkru: 0x%x shadow: 0x%x\n", __func__,
642                         __LINE__, ret, __rdpkru(), shadow_pkru);
643         return ret;
644 }
645
646 int mprotect_pkey(void *ptr, size_t size, unsigned long orig_prot,
647                 unsigned long pkey)
648 {
649         int nr_iterations = random() % 100;
650         int ret;
651
652         while (0) {
653                 int rpkey = alloc_random_pkey();
654                 ret = sys_mprotect_pkey(ptr, size, orig_prot, pkey);
655                 dprintf1("sys_mprotect_pkey(%p, %zx, prot=0x%lx, pkey=%ld) ret: %d\n",
656                                 ptr, size, orig_prot, pkey, ret);
657                 if (nr_iterations-- < 0)
658                         break;
659
660                 dprintf1("%s()::%d, ret: %d pkru: 0x%x shadow: 0x%x\n", __func__,
661                         __LINE__, ret, __rdpkru(), shadow_pkru);
662                 sys_pkey_free(rpkey);
663                 dprintf1("%s()::%d, ret: %d pkru: 0x%x shadow: 0x%x\n", __func__,
664                         __LINE__, ret, __rdpkru(), shadow_pkru);
665         }
666         pkey_assert(pkey < NR_PKEYS);
667
668         ret = sys_mprotect_pkey(ptr, size, orig_prot, pkey);
669         dprintf1("mprotect_pkey(%p, %zx, prot=0x%lx, pkey=%ld) ret: %d\n",
670                         ptr, size, orig_prot, pkey, ret);
671         pkey_assert(!ret);
672         dprintf1("%s()::%d, ret: %d pkru: 0x%x shadow: 0x%x\n", __func__,
673                         __LINE__, ret, __rdpkru(), shadow_pkru);
674         return ret;
675 }
676
677 struct pkey_malloc_record {
678         void *ptr;
679         long size;
680         int prot;
681 };
682 struct pkey_malloc_record *pkey_malloc_records;
683 struct pkey_malloc_record *pkey_last_malloc_record;
684 long nr_pkey_malloc_records;
685 void record_pkey_malloc(void *ptr, long size, int prot)
686 {
687         long i;
688         struct pkey_malloc_record *rec = NULL;
689
690         for (i = 0; i < nr_pkey_malloc_records; i++) {
691                 rec = &pkey_malloc_records[i];
692                 /* find a free record */
693                 if (rec)
694                         break;
695         }
696         if (!rec) {
697                 /* every record is full */
698                 size_t old_nr_records = nr_pkey_malloc_records;
699                 size_t new_nr_records = (nr_pkey_malloc_records * 2 + 1);
700                 size_t new_size = new_nr_records * sizeof(struct pkey_malloc_record);
701                 dprintf2("new_nr_records: %zd\n", new_nr_records);
702                 dprintf2("new_size: %zd\n", new_size);
703                 pkey_malloc_records = realloc(pkey_malloc_records, new_size);
704                 pkey_assert(pkey_malloc_records != NULL);
705                 rec = &pkey_malloc_records[nr_pkey_malloc_records];
706                 /*
707                  * realloc() does not initialize memory, so zero it from
708                  * the first new record all the way to the end.
709                  */
710                 for (i = 0; i < new_nr_records - old_nr_records; i++)
711                         memset(rec + i, 0, sizeof(*rec));
712         }
713         dprintf3("filling malloc record[%d/%p]: {%p, %ld}\n",
714                 (int)(rec - pkey_malloc_records), rec, ptr, size);
715         rec->ptr = ptr;
716         rec->size = size;
717         rec->prot = prot;
718         pkey_last_malloc_record = rec;
719         nr_pkey_malloc_records++;
720 }
721
722 void free_pkey_malloc(void *ptr)
723 {
724         long i;
725         int ret;
726         dprintf3("%s(%p)\n", __func__, ptr);
727         for (i = 0; i < nr_pkey_malloc_records; i++) {
728                 struct pkey_malloc_record *rec = &pkey_malloc_records[i];
729                 dprintf4("looking for ptr %p at record[%ld/%p]: {%p, %ld}\n",
730                                 ptr, i, rec, rec->ptr, rec->size);
731                 if ((ptr <  rec->ptr) ||
732                     (ptr >= rec->ptr + rec->size))
733                         continue;
734
735                 dprintf3("found ptr %p at record[%ld/%p]: {%p, %ld}\n",
736                                 ptr, i, rec, rec->ptr, rec->size);
737                 nr_pkey_malloc_records--;
738                 ret = munmap(rec->ptr, rec->size);
739                 dprintf3("munmap ret: %d\n", ret);
740                 pkey_assert(!ret);
741                 dprintf3("clearing rec->ptr, rec: %p\n", rec);
742                 rec->ptr = NULL;
743                 dprintf3("done clearing rec->ptr, rec: %p\n", rec);
744                 return;
745         }
746         pkey_assert(false);
747 }
748
749
750 void *malloc_pkey_with_mprotect(long size, int prot, u16 pkey)
751 {
752         void *ptr;
753         int ret;
754
755         rdpkru();
756         dprintf1("doing %s(size=%ld, prot=0x%x, pkey=%d)\n", __func__,
757                         size, prot, pkey);
758         pkey_assert(pkey < NR_PKEYS);
759         ptr = mmap(NULL, size, prot, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
760         pkey_assert(ptr != (void *)-1);
761         ret = mprotect_pkey((void *)ptr, PAGE_SIZE, prot, pkey);
762         pkey_assert(!ret);
763         record_pkey_malloc(ptr, size, prot);
764         rdpkru();
765
766         dprintf1("%s() for pkey %d @ %p\n", __func__, pkey, ptr);
767         return ptr;
768 }
769
770 void *malloc_pkey_anon_huge(long size, int prot, u16 pkey)
771 {
772         int ret;
773         void *ptr;
774
775         dprintf1("doing %s(size=%ld, prot=0x%x, pkey=%d)\n", __func__,
776                         size, prot, pkey);
777         /*
778          * Guarantee we can fit at least one huge page in the resulting
779          * allocation by allocating space for 2:
780          */
781         size = ALIGN_UP(size, HPAGE_SIZE * 2);
782         ptr = mmap(NULL, size, PROT_NONE, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
783         pkey_assert(ptr != (void *)-1);
784         record_pkey_malloc(ptr, size, prot);
785         mprotect_pkey(ptr, size, prot, pkey);
786
787         dprintf1("unaligned ptr: %p\n", ptr);
788         ptr = ALIGN_PTR_UP(ptr, HPAGE_SIZE);
789         dprintf1("  aligned ptr: %p\n", ptr);
790         ret = madvise(ptr, HPAGE_SIZE, MADV_HUGEPAGE);
791         dprintf1("MADV_HUGEPAGE ret: %d\n", ret);
792         ret = madvise(ptr, HPAGE_SIZE, MADV_WILLNEED);
793         dprintf1("MADV_WILLNEED ret: %d\n", ret);
794         memset(ptr, 0, HPAGE_SIZE);
795
796         dprintf1("mmap()'d thp for pkey %d @ %p\n", pkey, ptr);
797         return ptr;
798 }
799
800 int hugetlb_setup_ok;
801 #define GET_NR_HUGE_PAGES 10
802 void setup_hugetlbfs(void)
803 {
804         int err;
805         int fd;
806         char buf[] = "123";
807
808         if (geteuid() != 0) {
809                 fprintf(stderr, "WARNING: not run as root, can not do hugetlb test\n");
810                 return;
811         }
812
813         cat_into_file(__stringify(GET_NR_HUGE_PAGES), "/proc/sys/vm/nr_hugepages");
814
815         /*
816          * Now go make sure that we got the pages and that they
817          * are 2M pages.  Someone might have made 1G the default.
818          */
819         fd = open("/sys/kernel/mm/hugepages/hugepages-2048kB/nr_hugepages", O_RDONLY);
820         if (fd < 0) {
821                 perror("opening sysfs 2M hugetlb config");
822                 return;
823         }
824
825         /* -1 to guarantee leaving the trailing \0 */
826         err = read(fd, buf, sizeof(buf)-1);
827         close(fd);
828         if (err <= 0) {
829                 perror("reading sysfs 2M hugetlb config");
830                 return;
831         }
832
833         if (atoi(buf) != GET_NR_HUGE_PAGES) {
834                 fprintf(stderr, "could not confirm 2M pages, got: '%s' expected %d\n",
835                         buf, GET_NR_HUGE_PAGES);
836                 return;
837         }
838
839         hugetlb_setup_ok = 1;
840 }
841
842 void *malloc_pkey_hugetlb(long size, int prot, u16 pkey)
843 {
844         void *ptr;
845         int flags = MAP_ANONYMOUS|MAP_PRIVATE|MAP_HUGETLB;
846
847         if (!hugetlb_setup_ok)
848                 return PTR_ERR_ENOTSUP;
849
850         dprintf1("doing %s(%ld, %x, %x)\n", __func__, size, prot, pkey);
851         size = ALIGN_UP(size, HPAGE_SIZE * 2);
852         pkey_assert(pkey < NR_PKEYS);
853         ptr = mmap(NULL, size, PROT_NONE, flags, -1, 0);
854         pkey_assert(ptr != (void *)-1);
855         mprotect_pkey(ptr, size, prot, pkey);
856
857         record_pkey_malloc(ptr, size, prot);
858
859         dprintf1("mmap()'d hugetlbfs for pkey %d @ %p\n", pkey, ptr);
860         return ptr;
861 }
862
863 void *malloc_pkey_mmap_dax(long size, int prot, u16 pkey)
864 {
865         void *ptr;
866         int fd;
867
868         dprintf1("doing %s(size=%ld, prot=0x%x, pkey=%d)\n", __func__,
869                         size, prot, pkey);
870         pkey_assert(pkey < NR_PKEYS);
871         fd = open("/dax/foo", O_RDWR);
872         pkey_assert(fd >= 0);
873
874         ptr = mmap(0, size, prot, MAP_SHARED, fd, 0);
875         pkey_assert(ptr != (void *)-1);
876
877         mprotect_pkey(ptr, size, prot, pkey);
878
879         record_pkey_malloc(ptr, size, prot);
880
881         dprintf1("mmap()'d for pkey %d @ %p\n", pkey, ptr);
882         close(fd);
883         return ptr;
884 }
885
886 void *(*pkey_malloc[])(long size, int prot, u16 pkey) = {
887
888         malloc_pkey_with_mprotect,
889         malloc_pkey_anon_huge,
890         malloc_pkey_hugetlb
891 /* can not do direct with the pkey_mprotect() API:
892         malloc_pkey_mmap_direct,
893         malloc_pkey_mmap_dax,
894 */
895 };
896
897 void *malloc_pkey(long size, int prot, u16 pkey)
898 {
899         void *ret;
900         static int malloc_type;
901         int nr_malloc_types = ARRAY_SIZE(pkey_malloc);
902
903         pkey_assert(pkey < NR_PKEYS);
904
905         while (1) {
906                 pkey_assert(malloc_type < nr_malloc_types);
907
908                 ret = pkey_malloc[malloc_type](size, prot, pkey);
909                 pkey_assert(ret != (void *)-1);
910
911                 malloc_type++;
912                 if (malloc_type >= nr_malloc_types)
913                         malloc_type = (random()%nr_malloc_types);
914
915                 /* try again if the malloc_type we tried is unsupported */
916                 if (ret == PTR_ERR_ENOTSUP)
917                         continue;
918
919                 break;
920         }
921
922         dprintf3("%s(%ld, prot=%x, pkey=%x) returning: %p\n", __func__,
923                         size, prot, pkey, ret);
924         return ret;
925 }
926
927 int last_pkru_faults;
928 #define UNKNOWN_PKEY -2
929 void expected_pk_fault(int pkey)
930 {
931         dprintf2("%s(): last_pkru_faults: %d pkru_faults: %d\n",
932                         __func__, last_pkru_faults, pkru_faults);
933         dprintf2("%s(%d): last_si_pkey: %d\n", __func__, pkey, last_si_pkey);
934         pkey_assert(last_pkru_faults + 1 == pkru_faults);
935
936        /*
937         * For exec-only memory, we do not know the pkey in
938         * advance, so skip this check.
939         */
940         if (pkey != UNKNOWN_PKEY)
941                 pkey_assert(last_si_pkey == pkey);
942
943         /*
944          * The signal handler shold have cleared out PKRU to let the
945          * test program continue.  We now have to restore it.
946          */
947         if (__rdpkru() != 0)
948                 pkey_assert(0);
949
950         __wrpkru(shadow_pkru);
951         dprintf1("%s() set PKRU=%x to restore state after signal nuked it\n",
952                         __func__, shadow_pkru);
953         last_pkru_faults = pkru_faults;
954         last_si_pkey = -1;
955 }
956
957 #define do_not_expect_pk_fault(msg)     do {                    \
958         if (last_pkru_faults != pkru_faults)                    \
959                 dprintf0("unexpected PK fault: %s\n", msg);     \
960         pkey_assert(last_pkru_faults == pkru_faults);           \
961 } while (0)
962
963 int test_fds[10] = { -1 };
964 int nr_test_fds;
965 void __save_test_fd(int fd)
966 {
967         pkey_assert(fd >= 0);
968         pkey_assert(nr_test_fds < ARRAY_SIZE(test_fds));
969         test_fds[nr_test_fds] = fd;
970         nr_test_fds++;
971 }
972
973 int get_test_read_fd(void)
974 {
975         int test_fd = open("/etc/passwd", O_RDONLY);
976         __save_test_fd(test_fd);
977         return test_fd;
978 }
979
980 void close_test_fds(void)
981 {
982         int i;
983
984         for (i = 0; i < nr_test_fds; i++) {
985                 if (test_fds[i] < 0)
986                         continue;
987                 close(test_fds[i]);
988                 test_fds[i] = -1;
989         }
990         nr_test_fds = 0;
991 }
992
993 #define barrier() __asm__ __volatile__("": : :"memory")
994 __attribute__((noinline)) int read_ptr(int *ptr)
995 {
996         /*
997          * Keep GCC from optimizing this away somehow
998          */
999         barrier();
1000         return *ptr;
1001 }
1002
1003 void test_read_of_write_disabled_region(int *ptr, u16 pkey)
1004 {
1005         int ptr_contents;
1006
1007         dprintf1("disabling write access to PKEY[1], doing read\n");
1008         pkey_write_deny(pkey);
1009         ptr_contents = read_ptr(ptr);
1010         dprintf1("*ptr: %d\n", ptr_contents);
1011         dprintf1("\n");
1012 }
1013 void test_read_of_access_disabled_region(int *ptr, u16 pkey)
1014 {
1015         int ptr_contents;
1016
1017         dprintf1("disabling access to PKEY[%02d], doing read @ %p\n", pkey, ptr);
1018         rdpkru();
1019         pkey_access_deny(pkey);
1020         ptr_contents = read_ptr(ptr);
1021         dprintf1("*ptr: %d\n", ptr_contents);
1022         expected_pk_fault(pkey);
1023 }
1024 void test_write_of_write_disabled_region(int *ptr, u16 pkey)
1025 {
1026         dprintf1("disabling write access to PKEY[%02d], doing write\n", pkey);
1027         pkey_write_deny(pkey);
1028         *ptr = __LINE__;
1029         expected_pk_fault(pkey);
1030 }
1031 void test_write_of_access_disabled_region(int *ptr, u16 pkey)
1032 {
1033         dprintf1("disabling access to PKEY[%02d], doing write\n", pkey);
1034         pkey_access_deny(pkey);
1035         *ptr = __LINE__;
1036         expected_pk_fault(pkey);
1037 }
1038 void test_kernel_write_of_access_disabled_region(int *ptr, u16 pkey)
1039 {
1040         int ret;
1041         int test_fd = get_test_read_fd();
1042
1043         dprintf1("disabling access to PKEY[%02d], "
1044                  "having kernel read() to buffer\n", pkey);
1045         pkey_access_deny(pkey);
1046         ret = read(test_fd, ptr, 1);
1047         dprintf1("read ret: %d\n", ret);
1048         pkey_assert(ret);
1049 }
1050 void test_kernel_write_of_write_disabled_region(int *ptr, u16 pkey)
1051 {
1052         int ret;
1053         int test_fd = get_test_read_fd();
1054
1055         pkey_write_deny(pkey);
1056         ret = read(test_fd, ptr, 100);
1057         dprintf1("read ret: %d\n", ret);
1058         if (ret < 0 && (DEBUG_LEVEL > 0))
1059                 perror("verbose read result (OK for this to be bad)");
1060         pkey_assert(ret);
1061 }
1062
1063 void test_kernel_gup_of_access_disabled_region(int *ptr, u16 pkey)
1064 {
1065         int pipe_ret, vmsplice_ret;
1066         struct iovec iov;
1067         int pipe_fds[2];
1068
1069         pipe_ret = pipe(pipe_fds);
1070
1071         pkey_assert(pipe_ret == 0);
1072         dprintf1("disabling access to PKEY[%02d], "
1073                  "having kernel vmsplice from buffer\n", pkey);
1074         pkey_access_deny(pkey);
1075         iov.iov_base = ptr;
1076         iov.iov_len = PAGE_SIZE;
1077         vmsplice_ret = vmsplice(pipe_fds[1], &iov, 1, SPLICE_F_GIFT);
1078         dprintf1("vmsplice() ret: %d\n", vmsplice_ret);
1079         pkey_assert(vmsplice_ret == -1);
1080
1081         close(pipe_fds[0]);
1082         close(pipe_fds[1]);
1083 }
1084
1085 void test_kernel_gup_write_to_write_disabled_region(int *ptr, u16 pkey)
1086 {
1087         int ignored = 0xdada;
1088         int futex_ret;
1089         int some_int = __LINE__;
1090
1091         dprintf1("disabling write to PKEY[%02d], "
1092                  "doing futex gunk in buffer\n", pkey);
1093         *ptr = some_int;
1094         pkey_write_deny(pkey);
1095         futex_ret = syscall(SYS_futex, ptr, FUTEX_WAIT, some_int-1, NULL,
1096                         &ignored, ignored);
1097         if (DEBUG_LEVEL > 0)
1098                 perror("futex");
1099         dprintf1("futex() ret: %d\n", futex_ret);
1100 }
1101
1102 /* Assumes that all pkeys other than 'pkey' are unallocated */
1103 void test_pkey_syscalls_on_non_allocated_pkey(int *ptr, u16 pkey)
1104 {
1105         int err;
1106         int i;
1107
1108         /* Note: 0 is the default pkey, so don't mess with it */
1109         for (i = 1; i < NR_PKEYS; i++) {
1110                 if (pkey == i)
1111                         continue;
1112
1113                 dprintf1("trying get/set/free to non-allocated pkey: %2d\n", i);
1114                 err = sys_pkey_free(i);
1115                 pkey_assert(err);
1116
1117                 err = sys_pkey_free(i);
1118                 pkey_assert(err);
1119
1120                 err = sys_mprotect_pkey(ptr, PAGE_SIZE, PROT_READ, i);
1121                 pkey_assert(err);
1122         }
1123 }
1124
1125 /* Assumes that all pkeys other than 'pkey' are unallocated */
1126 void test_pkey_syscalls_bad_args(int *ptr, u16 pkey)
1127 {
1128         int err;
1129         int bad_pkey = NR_PKEYS+99;
1130
1131         /* pass a known-invalid pkey in: */
1132         err = sys_mprotect_pkey(ptr, PAGE_SIZE, PROT_READ, bad_pkey);
1133         pkey_assert(err);
1134 }
1135
1136 void become_child(void)
1137 {
1138         pid_t forkret;
1139
1140         forkret = fork();
1141         pkey_assert(forkret >= 0);
1142         dprintf3("[%d] fork() ret: %d\n", getpid(), forkret);
1143
1144         if (!forkret) {
1145                 /* in the child */
1146                 return;
1147         }
1148         exit(0);
1149 }
1150
1151 /* Assumes that all pkeys other than 'pkey' are unallocated */
1152 void test_pkey_alloc_exhaust(int *ptr, u16 pkey)
1153 {
1154         int err;
1155         int allocated_pkeys[NR_PKEYS] = {0};
1156         int nr_allocated_pkeys = 0;
1157         int i;
1158
1159         for (i = 0; i < NR_PKEYS*3; i++) {
1160                 int new_pkey;
1161                 dprintf1("%s() alloc loop: %d\n", __func__, i);
1162                 new_pkey = alloc_pkey();
1163                 dprintf4("%s()::%d, err: %d pkru: 0x%x shadow: 0x%x\n", __func__,
1164                                 __LINE__, err, __rdpkru(), shadow_pkru);
1165                 rdpkru(); /* for shadow checking */
1166                 dprintf2("%s() errno: %d ENOSPC: %d\n", __func__, errno, ENOSPC);
1167                 if ((new_pkey == -1) && (errno == ENOSPC)) {
1168                         dprintf2("%s() failed to allocate pkey after %d tries\n",
1169                                 __func__, nr_allocated_pkeys);
1170                 } else {
1171                         /*
1172                          * Ensure the number of successes never
1173                          * exceeds the number of keys supported
1174                          * in the hardware.
1175                          */
1176                         pkey_assert(nr_allocated_pkeys < NR_PKEYS);
1177                         allocated_pkeys[nr_allocated_pkeys++] = new_pkey;
1178                 }
1179
1180                 /*
1181                  * Make sure that allocation state is properly
1182                  * preserved across fork().
1183                  */
1184                 if (i == NR_PKEYS*2)
1185                         become_child();
1186         }
1187
1188         dprintf3("%s()::%d\n", __func__, __LINE__);
1189
1190         /*
1191          * There are 16 pkeys supported in hardware.  Three are
1192          * allocated by the time we get here:
1193          *   1. The default key (0)
1194          *   2. One possibly consumed by an execute-only mapping.
1195          *   3. One allocated by the test code and passed in via
1196          *      'pkey' to this function.
1197          * Ensure that we can allocate at least another 13 (16-3).
1198          */
1199         pkey_assert(i >= NR_PKEYS-3);
1200
1201         for (i = 0; i < nr_allocated_pkeys; i++) {
1202                 err = sys_pkey_free(allocated_pkeys[i]);
1203                 pkey_assert(!err);
1204                 rdpkru(); /* for shadow checking */
1205         }
1206 }
1207
1208 /*
1209  * pkey 0 is special.  It is allocated by default, so you do not
1210  * have to call pkey_alloc() to use it first.  Make sure that it
1211  * is usable.
1212  */
1213 void test_mprotect_with_pkey_0(int *ptr, u16 pkey)
1214 {
1215         long size;
1216         int prot;
1217
1218         assert(pkey_last_malloc_record);
1219         size = pkey_last_malloc_record->size;
1220         /*
1221          * This is a bit of a hack.  But mprotect() requires
1222          * huge-page-aligned sizes when operating on hugetlbfs.
1223          * So, make sure that we use something that's a multiple
1224          * of a huge page when we can.
1225          */
1226         if (size >= HPAGE_SIZE)
1227                 size = HPAGE_SIZE;
1228         prot = pkey_last_malloc_record->prot;
1229
1230         /* Use pkey 0 */
1231         mprotect_pkey(ptr, size, prot, 0);
1232
1233         /* Make sure that we can set it back to the original pkey. */
1234         mprotect_pkey(ptr, size, prot, pkey);
1235 }
1236
1237 void test_ptrace_of_child(int *ptr, u16 pkey)
1238 {
1239         __attribute__((__unused__)) int peek_result;
1240         pid_t child_pid;
1241         void *ignored = 0;
1242         long ret;
1243         int status;
1244         /*
1245          * This is the "control" for our little expermient.  Make sure
1246          * we can always access it when ptracing.
1247          */
1248         int *plain_ptr_unaligned = malloc(HPAGE_SIZE);
1249         int *plain_ptr = ALIGN_PTR_UP(plain_ptr_unaligned, PAGE_SIZE);
1250
1251         /*
1252          * Fork a child which is an exact copy of this process, of course.
1253          * That means we can do all of our tests via ptrace() and then plain
1254          * memory access and ensure they work differently.
1255          */
1256         child_pid = fork_lazy_child();
1257         dprintf1("[%d] child pid: %d\n", getpid(), child_pid);
1258
1259         ret = ptrace(PTRACE_ATTACH, child_pid, ignored, ignored);
1260         if (ret)
1261                 perror("attach");
1262         dprintf1("[%d] attach ret: %ld %d\n", getpid(), ret, __LINE__);
1263         pkey_assert(ret != -1);
1264         ret = waitpid(child_pid, &status, WUNTRACED);
1265         if ((ret != child_pid) || !(WIFSTOPPED(status))) {
1266                 fprintf(stderr, "weird waitpid result %ld stat %x\n",
1267                                 ret, status);
1268                 pkey_assert(0);
1269         }
1270         dprintf2("waitpid ret: %ld\n", ret);
1271         dprintf2("waitpid status: %d\n", status);
1272
1273         pkey_access_deny(pkey);
1274         pkey_write_deny(pkey);
1275
1276         /* Write access, untested for now:
1277         ret = ptrace(PTRACE_POKEDATA, child_pid, peek_at, data);
1278         pkey_assert(ret != -1);
1279         dprintf1("poke at %p: %ld\n", peek_at, ret);
1280         */
1281
1282         /*
1283          * Try to access the pkey-protected "ptr" via ptrace:
1284          */
1285         ret = ptrace(PTRACE_PEEKDATA, child_pid, ptr, ignored);
1286         /* expect it to work, without an error: */
1287         pkey_assert(ret != -1);
1288         /* Now access from the current task, and expect an exception: */
1289         peek_result = read_ptr(ptr);
1290         expected_pk_fault(pkey);
1291
1292         /*
1293          * Try to access the NON-pkey-protected "plain_ptr" via ptrace:
1294          */
1295         ret = ptrace(PTRACE_PEEKDATA, child_pid, plain_ptr, ignored);
1296         /* expect it to work, without an error: */
1297         pkey_assert(ret != -1);
1298         /* Now access from the current task, and expect NO exception: */
1299         peek_result = read_ptr(plain_ptr);
1300         do_not_expect_pk_fault("read plain pointer after ptrace");
1301
1302         ret = ptrace(PTRACE_DETACH, child_pid, ignored, 0);
1303         pkey_assert(ret != -1);
1304
1305         ret = kill(child_pid, SIGKILL);
1306         pkey_assert(ret != -1);
1307
1308         wait(&status);
1309
1310         free(plain_ptr_unaligned);
1311 }
1312
1313 void *get_pointer_to_instructions(void)
1314 {
1315         void *p1;
1316
1317         p1 = ALIGN_PTR_UP(&lots_o_noops_around_write, PAGE_SIZE);
1318         dprintf3("&lots_o_noops: %p\n", &lots_o_noops_around_write);
1319         /* lots_o_noops_around_write should be page-aligned already */
1320         assert(p1 == &lots_o_noops_around_write);
1321
1322         /* Point 'p1' at the *second* page of the function: */
1323         p1 += PAGE_SIZE;
1324
1325         /*
1326          * Try to ensure we fault this in on next touch to ensure
1327          * we get an instruction fault as opposed to a data one
1328          */
1329         madvise(p1, PAGE_SIZE, MADV_DONTNEED);
1330
1331         return p1;
1332 }
1333
1334 void test_executing_on_unreadable_memory(int *ptr, u16 pkey)
1335 {
1336         void *p1;
1337         int scratch;
1338         int ptr_contents;
1339         int ret;
1340
1341         p1 = get_pointer_to_instructions();
1342         lots_o_noops_around_write(&scratch);
1343         ptr_contents = read_ptr(p1);
1344         dprintf2("ptr (%p) contents@%d: %x\n", p1, __LINE__, ptr_contents);
1345
1346         ret = mprotect_pkey(p1, PAGE_SIZE, PROT_EXEC, (u64)pkey);
1347         pkey_assert(!ret);
1348         pkey_access_deny(pkey);
1349
1350         dprintf2("pkru: %x\n", rdpkru());
1351
1352         /*
1353          * Make sure this is an *instruction* fault
1354          */
1355         madvise(p1, PAGE_SIZE, MADV_DONTNEED);
1356         lots_o_noops_around_write(&scratch);
1357         do_not_expect_pk_fault("executing on PROT_EXEC memory");
1358         ptr_contents = read_ptr(p1);
1359         dprintf2("ptr (%p) contents@%d: %x\n", p1, __LINE__, ptr_contents);
1360         expected_pk_fault(pkey);
1361 }
1362
1363 void test_implicit_mprotect_exec_only_memory(int *ptr, u16 pkey)
1364 {
1365         void *p1;
1366         int scratch;
1367         int ptr_contents;
1368         int ret;
1369
1370         dprintf1("%s() start\n", __func__);
1371
1372         p1 = get_pointer_to_instructions();
1373         lots_o_noops_around_write(&scratch);
1374         ptr_contents = read_ptr(p1);
1375         dprintf2("ptr (%p) contents@%d: %x\n", p1, __LINE__, ptr_contents);
1376
1377         /* Use a *normal* mprotect(), not mprotect_pkey(): */
1378         ret = mprotect(p1, PAGE_SIZE, PROT_EXEC);
1379         pkey_assert(!ret);
1380
1381         dprintf2("pkru: %x\n", rdpkru());
1382
1383         /* Make sure this is an *instruction* fault */
1384         madvise(p1, PAGE_SIZE, MADV_DONTNEED);
1385         lots_o_noops_around_write(&scratch);
1386         do_not_expect_pk_fault("executing on PROT_EXEC memory");
1387         ptr_contents = read_ptr(p1);
1388         dprintf2("ptr (%p) contents@%d: %x\n", p1, __LINE__, ptr_contents);
1389         expected_pk_fault(UNKNOWN_PKEY);
1390
1391         /*
1392          * Put the memory back to non-PROT_EXEC.  Should clear the
1393          * exec-only pkey off the VMA and allow it to be readable
1394          * again.  Go to PROT_NONE first to check for a kernel bug
1395          * that did not clear the pkey when doing PROT_NONE.
1396          */
1397         ret = mprotect(p1, PAGE_SIZE, PROT_NONE);
1398         pkey_assert(!ret);
1399
1400         ret = mprotect(p1, PAGE_SIZE, PROT_READ|PROT_EXEC);
1401         pkey_assert(!ret);
1402         ptr_contents = read_ptr(p1);
1403         do_not_expect_pk_fault("plain read on recently PROT_EXEC area");
1404 }
1405
1406 void test_mprotect_pkey_on_unsupported_cpu(int *ptr, u16 pkey)
1407 {
1408         int size = PAGE_SIZE;
1409         int sret;
1410
1411         if (cpu_has_pku()) {
1412                 dprintf1("SKIP: %s: no CPU support\n", __func__);
1413                 return;
1414         }
1415
1416         sret = syscall(SYS_mprotect_key, ptr, size, PROT_READ, pkey);
1417         pkey_assert(sret < 0);
1418 }
1419
1420 void (*pkey_tests[])(int *ptr, u16 pkey) = {
1421         test_read_of_write_disabled_region,
1422         test_read_of_access_disabled_region,
1423         test_write_of_write_disabled_region,
1424         test_write_of_access_disabled_region,
1425         test_kernel_write_of_access_disabled_region,
1426         test_kernel_write_of_write_disabled_region,
1427         test_kernel_gup_of_access_disabled_region,
1428         test_kernel_gup_write_to_write_disabled_region,
1429         test_executing_on_unreadable_memory,
1430         test_implicit_mprotect_exec_only_memory,
1431         test_mprotect_with_pkey_0,
1432         test_ptrace_of_child,
1433         test_pkey_syscalls_on_non_allocated_pkey,
1434         test_pkey_syscalls_bad_args,
1435         test_pkey_alloc_exhaust,
1436 };
1437
1438 void run_tests_once(void)
1439 {
1440         int *ptr;
1441         int prot = PROT_READ|PROT_WRITE;
1442
1443         for (test_nr = 0; test_nr < ARRAY_SIZE(pkey_tests); test_nr++) {
1444                 int pkey;
1445                 int orig_pkru_faults = pkru_faults;
1446
1447                 dprintf1("======================\n");
1448                 dprintf1("test %d preparing...\n", test_nr);
1449
1450                 tracing_on();
1451                 pkey = alloc_random_pkey();
1452                 dprintf1("test %d starting with pkey: %d\n", test_nr, pkey);
1453                 ptr = malloc_pkey(PAGE_SIZE, prot, pkey);
1454                 dprintf1("test %d starting...\n", test_nr);
1455                 pkey_tests[test_nr](ptr, pkey);
1456                 dprintf1("freeing test memory: %p\n", ptr);
1457                 free_pkey_malloc(ptr);
1458                 sys_pkey_free(pkey);
1459
1460                 dprintf1("pkru_faults: %d\n", pkru_faults);
1461                 dprintf1("orig_pkru_faults: %d\n", orig_pkru_faults);
1462
1463                 tracing_off();
1464                 close_test_fds();
1465
1466                 printf("test %2d PASSED (iteration %d)\n", test_nr, iteration_nr);
1467                 dprintf1("======================\n\n");
1468         }
1469         iteration_nr++;
1470 }
1471
1472 void pkey_setup_shadow(void)
1473 {
1474         shadow_pkru = __rdpkru();
1475 }
1476
1477 int main(void)
1478 {
1479         int nr_iterations = 22;
1480
1481         srand((unsigned int)time(NULL));
1482
1483         setup_handlers();
1484
1485         printf("has pku: %d\n", cpu_has_pku());
1486
1487         if (!cpu_has_pku()) {
1488                 int size = PAGE_SIZE;
1489                 int *ptr;
1490
1491                 printf("running PKEY tests for unsupported CPU/OS\n");
1492
1493                 ptr  = mmap(NULL, size, PROT_NONE, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
1494                 assert(ptr != (void *)-1);
1495                 test_mprotect_pkey_on_unsupported_cpu(ptr, 1);
1496                 exit(0);
1497         }
1498
1499         pkey_setup_shadow();
1500         printf("startup pkru: %x\n", rdpkru());
1501         setup_hugetlbfs();
1502
1503         while (nr_iterations-- > 0)
1504                 run_tests_once();
1505
1506         printf("done (all tests OK)\n");
1507         return 0;
1508 }