1// SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB
2/*
3 * Copyright (c) 2022-2023 Fujitsu Ltd. All rights reserved.
4 */
5
6#include <linux/hmm.h>
7#include <linux/libnvdimm.h>
8
9#include <rdma/ib_umem_odp.h>
10
11#include "rxe.h"
12
13static bool rxe_ib_invalidate_range(struct mmu_interval_notifier *mni,
14 const struct mmu_notifier_range *range,
15 unsigned long cur_seq)
16{
17 struct ib_umem_odp *umem_odp =
18 container_of(mni, struct ib_umem_odp, notifier);
19 unsigned long start, end;
20
21 if (!mmu_notifier_range_blockable(range))
22 return false;
23
24 mutex_lock(&umem_odp->umem_mutex);
25 mmu_interval_set_seq(interval_sub: mni, cur_seq);
26
27 start = max_t(u64, ib_umem_start(umem_odp), range->start);
28 end = min_t(u64, ib_umem_end(umem_odp), range->end);
29
30 /* update umem_odp->map.pfn_list */
31 ib_umem_odp_unmap_dma_pages(umem_odp, start_offset: start, bound: end);
32
33 mutex_unlock(lock: &umem_odp->umem_mutex);
34 return true;
35}
36
37const struct mmu_interval_notifier_ops rxe_mn_ops = {
38 .invalidate = rxe_ib_invalidate_range,
39};
40
41#define RXE_PAGEFAULT_DEFAULT 0
42#define RXE_PAGEFAULT_RDONLY BIT(0)
43#define RXE_PAGEFAULT_SNAPSHOT BIT(1)
44static int rxe_odp_do_pagefault_and_lock(struct rxe_mr *mr, u64 user_va, int bcnt, u32 flags)
45{
46 struct ib_umem_odp *umem_odp = to_ib_umem_odp(umem: mr->umem);
47 bool fault = !(flags & RXE_PAGEFAULT_SNAPSHOT);
48 u64 access_mask = 0;
49 int np;
50
51 if (umem_odp->umem.writable && !(flags & RXE_PAGEFAULT_RDONLY))
52 access_mask |= HMM_PFN_WRITE;
53
54 /*
55 * ib_umem_odp_map_dma_and_lock() locks umem_mutex on success.
56 * Callers must release the lock later to let invalidation handler
57 * do its work again.
58 */
59 np = ib_umem_odp_map_dma_and_lock(umem_odp, start_offset: user_va, bcnt,
60 access_mask, fault);
61 return np;
62}
63
64static int rxe_odp_init_pages(struct rxe_mr *mr)
65{
66 struct ib_umem_odp *umem_odp = to_ib_umem_odp(umem: mr->umem);
67 int ret;
68
69 ret = rxe_odp_do_pagefault_and_lock(mr, user_va: mr->umem->address,
70 bcnt: mr->umem->length,
71 RXE_PAGEFAULT_SNAPSHOT);
72
73 if (ret >= 0)
74 mutex_unlock(lock: &umem_odp->umem_mutex);
75
76 return ret >= 0 ? 0 : ret;
77}
78
79int rxe_odp_mr_init_user(struct rxe_dev *rxe, u64 start, u64 length,
80 u64 iova, int access_flags, struct rxe_mr *mr)
81{
82 struct ib_umem_odp *umem_odp;
83 int err;
84
85 if (!IS_ENABLED(CONFIG_INFINIBAND_ON_DEMAND_PAGING))
86 return -EOPNOTSUPP;
87
88 rxe_mr_init(access: access_flags, mr);
89
90 if (!start && length == U64_MAX) {
91 if (iova != 0)
92 return -EINVAL;
93 if (!(rxe->attr.odp_caps.general_caps & IB_ODP_SUPPORT_IMPLICIT))
94 return -EINVAL;
95
96 /* Never reach here, for implicit ODP is not implemented. */
97 }
98
99 umem_odp = ib_umem_odp_get(device: &rxe->ib_dev, addr: start, size: length, access: access_flags,
100 ops: &rxe_mn_ops);
101 if (IS_ERR(ptr: umem_odp)) {
102 rxe_dbg_mr(mr, "Unable to create umem_odp err = %d\n",
103 (int)PTR_ERR(umem_odp));
104 return PTR_ERR(ptr: umem_odp);
105 }
106
107 umem_odp->private = mr;
108
109 mr->umem = &umem_odp->umem;
110 mr->access = access_flags;
111 mr->ibmr.length = length;
112 mr->ibmr.iova = iova;
113 mr->page_offset = ib_umem_offset(umem: &umem_odp->umem);
114
115 err = rxe_odp_init_pages(mr);
116 if (err) {
117 ib_umem_odp_release(umem_odp);
118 return err;
119 }
120
121 mr->state = RXE_MR_STATE_VALID;
122 mr->ibmr.type = IB_MR_TYPE_USER;
123
124 return err;
125}
126
127static inline bool rxe_check_pagefault(struct ib_umem_odp *umem_odp, u64 iova,
128 int length)
129{
130 bool need_fault = false;
131 u64 addr;
132 int idx;
133
134 addr = iova & (~(BIT(umem_odp->page_shift) - 1));
135
136 /* Skim through all pages that are to be accessed. */
137 while (addr < iova + length) {
138 idx = (addr - ib_umem_start(umem_odp)) >> umem_odp->page_shift;
139
140 if (!(umem_odp->map.pfn_list[idx] & HMM_PFN_VALID)) {
141 need_fault = true;
142 break;
143 }
144
145 addr += BIT(umem_odp->page_shift);
146 }
147 return need_fault;
148}
149
150static unsigned long rxe_odp_iova_to_index(struct ib_umem_odp *umem_odp, u64 iova)
151{
152 return (iova - ib_umem_start(umem_odp)) >> umem_odp->page_shift;
153}
154
155static unsigned long rxe_odp_iova_to_page_offset(struct ib_umem_odp *umem_odp, u64 iova)
156{
157 return iova & (BIT(umem_odp->page_shift) - 1);
158}
159
160static int rxe_odp_map_range_and_lock(struct rxe_mr *mr, u64 iova, int length, u32 flags)
161{
162 struct ib_umem_odp *umem_odp = to_ib_umem_odp(umem: mr->umem);
163 bool need_fault;
164 int err;
165
166 if (unlikely(length < 1))
167 return -EINVAL;
168
169 mutex_lock(&umem_odp->umem_mutex);
170
171 need_fault = rxe_check_pagefault(umem_odp, iova, length);
172 if (need_fault) {
173 mutex_unlock(lock: &umem_odp->umem_mutex);
174
175 /* umem_mutex is locked on success. */
176 err = rxe_odp_do_pagefault_and_lock(mr, user_va: iova, bcnt: length,
177 flags);
178 if (err < 0)
179 return err;
180
181 need_fault = rxe_check_pagefault(umem_odp, iova, length);
182 if (need_fault) {
183 mutex_unlock(lock: &umem_odp->umem_mutex);
184 return -EFAULT;
185 }
186 }
187
188 return 0;
189}
190
191static int __rxe_odp_mr_copy(struct rxe_mr *mr, u64 iova, void *addr,
192 int length, enum rxe_mr_copy_dir dir)
193{
194 struct ib_umem_odp *umem_odp = to_ib_umem_odp(umem: mr->umem);
195 struct page *page;
196 int idx, bytes;
197 size_t offset;
198 u8 *user_va;
199
200 idx = rxe_odp_iova_to_index(umem_odp, iova);
201 offset = rxe_odp_iova_to_page_offset(umem_odp, iova);
202
203 while (length > 0) {
204 u8 *src, *dest;
205
206 page = hmm_pfn_to_page(hmm_pfn: umem_odp->map.pfn_list[idx]);
207 user_va = kmap_local_page(page);
208
209 src = (dir == RXE_TO_MR_OBJ) ? addr : user_va;
210 dest = (dir == RXE_TO_MR_OBJ) ? user_va : addr;
211
212 bytes = BIT(umem_odp->page_shift) - offset;
213 if (bytes > length)
214 bytes = length;
215
216 memcpy(dest, src, bytes);
217 kunmap_local(user_va);
218
219 length -= bytes;
220 idx++;
221 offset = 0;
222 }
223
224 return 0;
225}
226
227int rxe_odp_mr_copy(struct rxe_mr *mr, u64 iova, void *addr, int length,
228 enum rxe_mr_copy_dir dir)
229{
230 struct ib_umem_odp *umem_odp = to_ib_umem_odp(umem: mr->umem);
231 u32 flags = RXE_PAGEFAULT_DEFAULT;
232 int err;
233
234 if (length == 0)
235 return 0;
236
237 if (unlikely(!mr->umem->is_odp))
238 return -EOPNOTSUPP;
239
240 switch (dir) {
241 case RXE_TO_MR_OBJ:
242 break;
243
244 case RXE_FROM_MR_OBJ:
245 flags |= RXE_PAGEFAULT_RDONLY;
246 break;
247
248 default:
249 return -EINVAL;
250 }
251
252 err = rxe_odp_map_range_and_lock(mr, iova, length, flags);
253 if (err)
254 return err;
255
256 err = __rxe_odp_mr_copy(mr, iova, addr, length, dir);
257
258 mutex_unlock(lock: &umem_odp->umem_mutex);
259
260 return err;
261}
262
263static enum resp_states rxe_odp_do_atomic_op(struct rxe_mr *mr, u64 iova,
264 int opcode, u64 compare,
265 u64 swap_add, u64 *orig_val)
266{
267 struct ib_umem_odp *umem_odp = to_ib_umem_odp(umem: mr->umem);
268 unsigned int page_offset;
269 struct page *page;
270 unsigned int idx;
271 u64 value;
272 u64 *va;
273 int err;
274
275 if (unlikely(mr->state != RXE_MR_STATE_VALID)) {
276 rxe_dbg_mr(mr, "mr not in valid state\n");
277 return RESPST_ERR_RKEY_VIOLATION;
278 }
279
280 err = mr_check_range(mr, iova, length: sizeof(value));
281 if (err) {
282 rxe_dbg_mr(mr, "iova out of range\n");
283 return RESPST_ERR_RKEY_VIOLATION;
284 }
285
286 page_offset = rxe_odp_iova_to_page_offset(umem_odp, iova);
287 if (unlikely(page_offset & 0x7)) {
288 rxe_dbg_mr(mr, "iova not aligned\n");
289 return RESPST_ERR_MISALIGNED_ATOMIC;
290 }
291
292 idx = rxe_odp_iova_to_index(umem_odp, iova);
293 page = hmm_pfn_to_page(hmm_pfn: umem_odp->map.pfn_list[idx]);
294
295 va = kmap_local_page(page);
296
297 spin_lock_bh(lock: &atomic_ops_lock);
298 value = *orig_val = va[page_offset >> 3];
299
300 if (opcode == IB_OPCODE_RC_COMPARE_SWAP) {
301 if (value == compare)
302 va[page_offset >> 3] = swap_add;
303 } else {
304 value += swap_add;
305 va[page_offset >> 3] = value;
306 }
307 spin_unlock_bh(lock: &atomic_ops_lock);
308
309 kunmap_local(va);
310
311 return RESPST_NONE;
312}
313
314enum resp_states rxe_odp_atomic_op(struct rxe_mr *mr, u64 iova, int opcode,
315 u64 compare, u64 swap_add, u64 *orig_val)
316{
317 struct ib_umem_odp *umem_odp = to_ib_umem_odp(umem: mr->umem);
318 int err;
319
320 err = rxe_odp_map_range_and_lock(mr, iova, length: sizeof(char),
321 RXE_PAGEFAULT_DEFAULT);
322 if (err < 0)
323 return RESPST_ERR_RKEY_VIOLATION;
324
325 err = rxe_odp_do_atomic_op(mr, iova, opcode, compare, swap_add,
326 orig_val);
327 mutex_unlock(lock: &umem_odp->umem_mutex);
328
329 return err;
330}
331
332int rxe_odp_flush_pmem_iova(struct rxe_mr *mr, u64 iova,
333 unsigned int length)
334{
335 struct ib_umem_odp *umem_odp = to_ib_umem_odp(umem: mr->umem);
336 unsigned int page_offset;
337 unsigned long index;
338 struct page *page;
339 unsigned int bytes;
340 int err;
341 u8 *va;
342
343 err = rxe_odp_map_range_and_lock(mr, iova, length,
344 RXE_PAGEFAULT_DEFAULT);
345 if (err)
346 return err;
347
348 while (length > 0) {
349 index = rxe_odp_iova_to_index(umem_odp, iova);
350 page_offset = rxe_odp_iova_to_page_offset(umem_odp, iova);
351
352 page = hmm_pfn_to_page(hmm_pfn: umem_odp->map.pfn_list[index]);
353
354 bytes = min_t(unsigned int, length,
355 mr_page_size(mr) - page_offset);
356
357 va = kmap_local_page(page);
358 arch_wb_cache_pmem(addr: va + page_offset, size: bytes);
359 kunmap_local(va);
360
361 length -= bytes;
362 iova += bytes;
363 }
364
365 mutex_unlock(lock: &umem_odp->umem_mutex);
366
367 return 0;
368}
369
370enum resp_states rxe_odp_do_atomic_write(struct rxe_mr *mr, u64 iova, u64 value)
371{
372 struct ib_umem_odp *umem_odp = to_ib_umem_odp(umem: mr->umem);
373 unsigned int page_offset;
374 unsigned long index;
375 struct page *page;
376 int err;
377 u64 *va;
378
379 /* See IBA oA19-28 */
380 err = mr_check_range(mr, iova, length: sizeof(value));
381 if (unlikely(err)) {
382 rxe_dbg_mr(mr, "iova out of range\n");
383 return RESPST_ERR_RKEY_VIOLATION;
384 }
385
386 err = rxe_odp_map_range_and_lock(mr, iova, length: sizeof(value),
387 RXE_PAGEFAULT_DEFAULT);
388 if (err)
389 return RESPST_ERR_RKEY_VIOLATION;
390
391 page_offset = rxe_odp_iova_to_page_offset(umem_odp, iova);
392 /* See IBA A19.4.2 */
393 if (unlikely(page_offset & 0x7)) {
394 mutex_unlock(lock: &umem_odp->umem_mutex);
395 rxe_dbg_mr(mr, "misaligned address\n");
396 return RESPST_ERR_MISALIGNED_ATOMIC;
397 }
398
399 index = rxe_odp_iova_to_index(umem_odp, iova);
400 page = hmm_pfn_to_page(hmm_pfn: umem_odp->map.pfn_list[index]);
401
402 va = kmap_local_page(page);
403 /* Do atomic write after all prior operations have completed */
404 smp_store_release(&va[page_offset >> 3], value);
405 kunmap_local(va);
406
407 mutex_unlock(lock: &umem_odp->umem_mutex);
408
409 return RESPST_NONE;
410}
411
412struct prefetch_mr_work {
413 struct work_struct work;
414 u32 pf_flags;
415 u32 num_sge;
416 struct {
417 u64 io_virt;
418 struct rxe_mr *mr;
419 size_t length;
420 } frags[];
421};
422
423static void rxe_ib_prefetch_mr_work(struct work_struct *w)
424{
425 struct prefetch_mr_work *work =
426 container_of(w, struct prefetch_mr_work, work);
427 int ret;
428 u32 i;
429
430 /*
431 * We rely on IB/core that work is executed
432 * if we have num_sge != 0 only.
433 */
434 WARN_ON(!work->num_sge);
435 for (i = 0; i < work->num_sge; ++i) {
436 struct ib_umem_odp *umem_odp;
437
438 ret = rxe_odp_do_pagefault_and_lock(mr: work->frags[i].mr,
439 user_va: work->frags[i].io_virt,
440 bcnt: work->frags[i].length,
441 flags: work->pf_flags);
442 if (ret < 0) {
443 rxe_dbg_mr(work->frags[i].mr,
444 "failed to prefetch the mr\n");
445 goto deref;
446 }
447
448 umem_odp = to_ib_umem_odp(umem: work->frags[i].mr->umem);
449 mutex_unlock(lock: &umem_odp->umem_mutex);
450
451deref:
452 rxe_put(work->frags[i].mr);
453 }
454
455 kvfree(addr: work);
456}
457
458static int rxe_ib_prefetch_sg_list(struct ib_pd *ibpd,
459 enum ib_uverbs_advise_mr_advice advice,
460 u32 pf_flags, struct ib_sge *sg_list,
461 u32 num_sge)
462{
463 struct rxe_pd *pd = container_of(ibpd, struct rxe_pd, ibpd);
464 int ret = 0;
465 u32 i;
466
467 for (i = 0; i < num_sge; ++i) {
468 struct rxe_mr *mr;
469 struct ib_umem_odp *umem_odp;
470
471 mr = lookup_mr(pd, access: IB_ACCESS_LOCAL_WRITE,
472 key: sg_list[i].lkey, type: RXE_LOOKUP_LOCAL);
473
474 if (!mr) {
475 rxe_dbg_pd(pd, "mr with lkey %x not found\n",
476 sg_list[i].lkey);
477 return -EINVAL;
478 }
479
480 if (advice == IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_WRITE &&
481 !mr->umem->writable) {
482 rxe_dbg_mr(mr, "missing write permission\n");
483 rxe_put(mr);
484 return -EPERM;
485 }
486
487 ret = rxe_odp_do_pagefault_and_lock(
488 mr, user_va: sg_list[i].addr, bcnt: sg_list[i].length, flags: pf_flags);
489 if (ret < 0) {
490 rxe_dbg_mr(mr, "failed to prefetch the mr\n");
491 rxe_put(mr);
492 return ret;
493 }
494
495 umem_odp = to_ib_umem_odp(umem: mr->umem);
496 mutex_unlock(lock: &umem_odp->umem_mutex);
497
498 rxe_put(mr);
499 }
500
501 return 0;
502}
503
504static int rxe_ib_advise_mr_prefetch(struct ib_pd *ibpd,
505 enum ib_uverbs_advise_mr_advice advice,
506 u32 flags, struct ib_sge *sg_list,
507 u32 num_sge)
508{
509 struct rxe_pd *pd = container_of(ibpd, struct rxe_pd, ibpd);
510 u32 pf_flags = RXE_PAGEFAULT_DEFAULT;
511 struct prefetch_mr_work *work;
512 struct rxe_mr *mr;
513 u32 i;
514
515 if (advice == IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH)
516 pf_flags |= RXE_PAGEFAULT_RDONLY;
517
518 if (advice == IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_NO_FAULT)
519 pf_flags |= RXE_PAGEFAULT_SNAPSHOT;
520
521 /* Synchronous call */
522 if (flags & IB_UVERBS_ADVISE_MR_FLAG_FLUSH)
523 return rxe_ib_prefetch_sg_list(ibpd, advice, pf_flags, sg_list,
524 num_sge);
525
526 /* Asynchronous call is "best-effort" and allowed to fail */
527 work = kvzalloc(struct_size(work, frags, num_sge), GFP_KERNEL);
528 if (!work)
529 return -ENOMEM;
530
531 INIT_WORK(&work->work, rxe_ib_prefetch_mr_work);
532 work->pf_flags = pf_flags;
533 work->num_sge = num_sge;
534
535 for (i = 0; i < num_sge; ++i) {
536 /* Takes a reference, which will be released in the queued work */
537 mr = lookup_mr(pd, access: IB_ACCESS_LOCAL_WRITE,
538 key: sg_list[i].lkey, type: RXE_LOOKUP_LOCAL);
539 if (!mr) {
540 mr = ERR_PTR(error: -EINVAL);
541 goto err;
542 }
543
544 work->frags[i].io_virt = sg_list[i].addr;
545 work->frags[i].length = sg_list[i].length;
546 work->frags[i].mr = mr;
547 }
548
549 queue_work(wq: system_unbound_wq, work: &work->work);
550
551 return 0;
552
553 err:
554 /* rollback reference counts for the invalid request */
555 while (i > 0) {
556 i--;
557 rxe_put(work->frags[i].mr);
558 }
559
560 kvfree(addr: work);
561
562 return PTR_ERR(ptr: mr);
563}
564
565int rxe_ib_advise_mr(struct ib_pd *ibpd,
566 enum ib_uverbs_advise_mr_advice advice,
567 u32 flags,
568 struct ib_sge *sg_list,
569 u32 num_sge,
570 struct uverbs_attr_bundle *attrs)
571{
572 if (advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH &&
573 advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_WRITE &&
574 advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_NO_FAULT)
575 return -EOPNOTSUPP;
576
577 return rxe_ib_advise_mr_prefetch(ibpd, advice, flags,
578 sg_list, num_sge);
579}
580

source code of linux/drivers/infiniband/sw/rxe/rxe_odp.c