@@ -930,7 +930,7 @@ cdef class Basic(object):
930930 if (len (f) != 1 ):
931931 raise RuntimeError (" Variable w.r.t should be given" )
932932 return self ._diff(f.pop())
933- return diff (self , * args)
933+ return _diff (self , * args)
934934
935935 def subs_dict (Basic self not None , *args ):
936936 warnings.warn(" subs_dict() is deprecated. Use subs() instead" , DeprecationWarning )
@@ -3687,7 +3687,7 @@ cdef class DenseMatrixBase(MatrixBase):
36873687 return R
36883688
36893689 def diff (self , *args ):
3690- return diff (self , * args)
3690+ return _diff (self , * args)
36913691
36923692 # TODO: implement this in C++
36933693 def subs (self , *args ):
@@ -4063,15 +4063,23 @@ def module_cleanup():
40634063import atexit
40644064atexit.register(module_cleanup)
40654065
4066+
40664067def diff (expr , *args ):
4067- cdef Basic ex = sympify(expr)
4068+ if isinstance (expr, MatrixBase):
4069+ # Don't sympify matrices so that mutable matrices
4070+ # return mutable matrices
4071+ return _diff(expr, * args)
4072+ return _diff(sympify(expr), * args)
4073+
4074+
4075+ def _diff (expr , *args ):
40684076 cdef Basic prev
40694077 cdef Basic b
40704078 cdef size_t i
40714079 cdef size_t length = len (args)
40724080
40734081 if not length:
4074- return ex
4082+ return expr
40754083
40764084 cdef size_t l = 0
40774085 cdef Basic cur_arg, next_arg
@@ -4083,20 +4091,20 @@ def diff(expr, *args):
40834091
40844092 if l + 1 == length:
40854093 # No next argument, differentiate with no integer argument
4086- return ex ._diff(cur_arg)
4094+ return expr ._diff(cur_arg)
40874095
40884096 next_arg = sympify(args[l + 1 ])
40894097 # Check if the next arg was derivative order
40904098 if isinstance (next_arg, Integer):
40914099 i = int (next_arg)
40924100 for _ in range (i):
4093- ex = ex ._diff(cur_arg)
4101+ expr = expr ._diff(cur_arg)
40944102 l += 2
40954103 if l == length:
4096- return ex
4104+ return expr
40974105 cur_arg = sympify(args[l])
40984106 else :
4099- ex = ex ._diff(cur_arg)
4107+ expr = expr ._diff(cur_arg)
41004108 l += 1
41014109 cur_arg = next_arg
41024110
0 commit comments