Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
febe7090
Commit
febe7090
authored
Jun 10, 2021
by
Philipp Arras
Browse files
Add explanation how to add nonlinearities to NIFTy
parent
2236cf7b
Pipeline
#103346
passed with stages
in 14 minutes and 10 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
.gitlab-ci.yml
View file @
febe7090
...
...
@@ -147,3 +147,8 @@ run_visual_vi:
  stage: demo_runs
  script:
    - python3 demos/variational_inference_visualized.py

run_nonlinearity_guide:
  stage: demo_runs
  script:
    - python3 demos/custom_nonlinearities.py
demos/custom_nonlinearities.py
0 → 100644
View file @
febe7090
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2021 Max-Planck-Society
# Author: Philipp Arras
import
nifty7
as
ift
import
numpy
as
np
# In NIFTy, users can add hand-crafted point-wise nonlinearities that are then
# available for `Field`, `MultiField`, `Linearization` and `Operator`. This
# guide shows an example of how this is done.

# Suppose we would like to use the function f(x) = x*exp(x) point-wise in an
# operator chain. This function is called "myptw" in the following. We
# introduce this function to NIFTy by implementing two functions.

# First, one that takes a `numpy.ndarray` as an input, applies the point-wise
# mapping and returns the result as a `numpy.ndarray` (of the same shape).
# Second, a function that takes a `numpy.ndarray` as an input and returns two
# `numpy.ndarray`s: the application of the nonlinearity (same as before) and
# the derivative, here f'(x) = (1+x)*exp(x).
def func(x):
    """Point-wise f(x) = x*exp(x).

    Takes a `numpy.ndarray` and returns an array of the same shape holding
    f applied to every entry.
    """
    growth = np.exp(x)
    return x*growth
def func_and_derv(x):
    """Point-wise f(x) = x*exp(x) together with its derivative.

    Returns the tuple (f(x), f'(x)) where f'(x) = (1+x)*exp(x). The exponential
    is evaluated once and shared by both results.
    """
    exp_of_x = np.exp(x)
    value = x*exp_of_x
    slope = (1+x)*exp_of_x
    return value, slope
# These two functions are then added to the NIFTy-internal dictionary that
# contains all implemented point-wise nonlinearities. The entry stored under
# the key "myptw" is the tuple (func, func_and_derv); that key is the name
# passed to `.ptw(...)` in the examples that follow.
ift.pointwise.ptw_dict["myptw"] = func, func_and_derv
# This allows us to apply this nonlinearity on `Field`s, ...
dom = ift.UnstructuredDomain(10)
# A constant test field (all entries 2.). The previous random draw was
# immediately overwritten by this value, so it has been removed. `dom` and
# `fld` are reused by the Linearization and Operator examples below.
fld = ift.full(dom, 2.)
a = fld.ptw("myptw")
# Reference result: apply `func` directly to the raw array of the field.
b = ift.makeField(dom, func(fld.val))
ift.extra.assert_allclose(a, b)
# `MultiField`s, ...
# A MultiDomain with a single entry "bar" and a random MultiField on it.
mdom = ift.makeDomain({"bar": ift.UnstructuredDomain(10)})
mfld = ift.from_random(mdom)
a = mfld.ptw("myptw")
# Reference result: `func` applied to the raw values of the entry "bar",
# wrapped back into a MultiField on `mdom`.
b = ift.makeField(mdom, {"bar": func(mfld["bar"].val)})
ift.extra.assert_allclose(a, b)
# Linearizations (including the Jacobian), ...
# (Value)
# Wrap `fld` (defined above) as a Linearization variable; `lin` is also used
# by `check` further down.
lin = ift.Linearization.make_var(fld)
a = lin.ptw("myptw").val
b = ift.makeField(dom, func(fld.val))
ift.extra.assert_allclose(a, b)
# (Jacobian)
# The Jacobian produced by NIFTy is compared to an operator built from the
# derivative returned by `func_and_derv` (index 1 of its result tuple).
op_a = lin.ptw("myptw").jac
op_b = ift.makeOp(ift.makeField(dom, func_and_derv(fld.val)[1]))
# Compare both operators by applying them to the same random test vector.
testing_vector = ift.from_random(dom)
ift.extra.assert_allclose(op_a(testing_vector), op_b(testing_vector))
# and `Operator`s.
# `FieldAdapter(dom, "foo")` reads the entry "foo" of its input; chaining
# `.ptw("myptw")` applies the new nonlinearity to that entry.
op = ift.FieldAdapter(dom, "foo").ptw("myptw")
# Actually exercise the operator: feed it a MultiField whose entry "foo" is
# `fld` and compare with applying `func` to the raw values directly.
a = op(ift.MultiField.from_dict({"foo": fld}))
b = ift.makeField(dom, func(fld.val))
ift.extra.assert_allclose(a, b)
# We check that the gradient has been implemented correctly by comparing it to
# an approximation to the gradient by finite differences.
def check(func_name, eps=1e-7):
    """Compare the Jacobian of a registered point-wise nonlinearity with a
    forward finite-difference approximation.

    Parameters
    ----------
    func_name : str
        Key under which the nonlinearity is registered in
        `ift.pointwise.ptw_dict`.
    eps : float
        Step size of the forward difference. It also sets the relative
        tolerance (`100*eps`) of the comparison.

    Raises
    ------
    AssertionError
        If `ift.extra.assert_allclose` finds the two derivative estimates
        disagree beyond the tolerance.
    """
    pos = ift.from_random(ift.UnstructuredDomain(10))
    var0 = ift.Linearization.make_var(pos)
    var1 = ift.Linearization.make_var(pos + eps)
    # Forward difference (f(x+eps) - f(x)) / eps, evaluated entry by entry.
    df0 = (var1.ptw(func_name).val - var0.ptw(func_name).val)/eps
    # Apply the Jacobian at `pos` to a field of ones on the same domain as
    # `pos` (previously this used the unrelated global `lin.domain`). For a
    # point-wise map this should yield the entry-wise derivative.
    df1 = var0.ptw(func_name).jac(ift.full(pos.domain, 1.))
    # rtol depends on how nonlinear the function is
    ift.extra.assert_allclose(df0, df1, rtol=100*eps)
check
(
"myptw"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment